add read me
This commit is contained in:
7
venv/lib/python3.12/site-packages/sklearn/externals/README
vendored
Normal file
7
venv/lib/python3.12/site-packages/sklearn/externals/README
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
This directory contains bundled external dependencies that are updated
|
||||
every once in a while.
|
||||
|
||||
Note for distribution packagers: if you want to remove the duplicated
|
||||
code and depend on a packaged version, we suggest that you simply do a
|
||||
symbolic link in this directory.
|
||||
|
||||
5
venv/lib/python3.12/site-packages/sklearn/externals/__init__.py
vendored
Normal file
5
venv/lib/python3.12/site-packages/sklearn/externals/__init__.py
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
|
||||
"""
|
||||
External, bundled dependencies.
|
||||
|
||||
"""
|
||||
BIN
venv/lib/python3.12/site-packages/sklearn/externals/__pycache__/__init__.cpython-312.pyc
vendored
Normal file
BIN
venv/lib/python3.12/site-packages/sklearn/externals/__pycache__/__init__.cpython-312.pyc
vendored
Normal file
Binary file not shown.
BIN
venv/lib/python3.12/site-packages/sklearn/externals/__pycache__/_arff.cpython-312.pyc
vendored
Normal file
BIN
venv/lib/python3.12/site-packages/sklearn/externals/__pycache__/_arff.cpython-312.pyc
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
venv/lib/python3.12/site-packages/sklearn/externals/__pycache__/conftest.cpython-312.pyc
vendored
Normal file
BIN
venv/lib/python3.12/site-packages/sklearn/externals/__pycache__/conftest.cpython-312.pyc
vendored
Normal file
Binary file not shown.
1107
venv/lib/python3.12/site-packages/sklearn/externals/_arff.py
vendored
Normal file
1107
venv/lib/python3.12/site-packages/sklearn/externals/_arff.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
5
venv/lib/python3.12/site-packages/sklearn/externals/_array_api_compat_vendor.py
vendored
Normal file
5
venv/lib/python3.12/site-packages/sklearn/externals/_array_api_compat_vendor.py
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
# DO NOT RENAME THIS FILE
|
||||
# This is a hook for array_api_extra/_lib/_compat.py
|
||||
# to co-vendor array_api_compat and potentially override its functions.
|
||||
|
||||
from .array_api_compat import * # noqa: F403
|
||||
0
venv/lib/python3.12/site-packages/sklearn/externals/_packaging/__init__.py
vendored
Normal file
0
venv/lib/python3.12/site-packages/sklearn/externals/_packaging/__init__.py
vendored
Normal file
BIN
venv/lib/python3.12/site-packages/sklearn/externals/_packaging/__pycache__/__init__.cpython-312.pyc
vendored
Normal file
BIN
venv/lib/python3.12/site-packages/sklearn/externals/_packaging/__pycache__/__init__.cpython-312.pyc
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
venv/lib/python3.12/site-packages/sklearn/externals/_packaging/__pycache__/version.cpython-312.pyc
vendored
Normal file
BIN
venv/lib/python3.12/site-packages/sklearn/externals/_packaging/__pycache__/version.cpython-312.pyc
vendored
Normal file
Binary file not shown.
90
venv/lib/python3.12/site-packages/sklearn/externals/_packaging/_structures.py
vendored
Normal file
90
venv/lib/python3.12/site-packages/sklearn/externals/_packaging/_structures.py
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Vendoered from
|
||||
https://github.com/pypa/packaging/blob/main/packaging/_structures.py
|
||||
"""
|
||||
# Copyright (c) Donald Stufft and individual contributors.
|
||||
# All rights reserved.
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
class InfinityType:
|
||||
def __repr__(self) -> str:
|
||||
return "Infinity"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(repr(self))
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
return False
|
||||
|
||||
def __le__(self, other: object) -> bool:
|
||||
return False
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__)
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not isinstance(other, self.__class__)
|
||||
|
||||
def __gt__(self, other: object) -> bool:
|
||||
return True
|
||||
|
||||
def __ge__(self, other: object) -> bool:
|
||||
return True
|
||||
|
||||
def __neg__(self: object) -> "NegativeInfinityType":
|
||||
return NegativeInfinity
|
||||
|
||||
|
||||
Infinity = InfinityType()
|
||||
|
||||
|
||||
class NegativeInfinityType:
|
||||
def __repr__(self) -> str:
|
||||
return "-Infinity"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(repr(self))
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
return True
|
||||
|
||||
def __le__(self, other: object) -> bool:
|
||||
return True
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__)
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not isinstance(other, self.__class__)
|
||||
|
||||
def __gt__(self, other: object) -> bool:
|
||||
return False
|
||||
|
||||
def __ge__(self, other: object) -> bool:
|
||||
return False
|
||||
|
||||
def __neg__(self: object) -> InfinityType:
|
||||
return Infinity
|
||||
|
||||
|
||||
NegativeInfinity = NegativeInfinityType()
|
||||
535
venv/lib/python3.12/site-packages/sklearn/externals/_packaging/version.py
vendored
Normal file
535
venv/lib/python3.12/site-packages/sklearn/externals/_packaging/version.py
vendored
Normal file
@@ -0,0 +1,535 @@
|
||||
"""Vendoered from
|
||||
https://github.com/pypa/packaging/blob/main/packaging/version.py
|
||||
"""
|
||||
# Copyright (c) Donald Stufft and individual contributors.
|
||||
# All rights reserved.
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
import re
|
||||
import warnings
|
||||
from typing import Callable, Iterator, List, Optional, SupportsInt, Tuple, Union
|
||||
|
||||
from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType
|
||||
|
||||
__all__ = ["parse", "Version", "LegacyVersion", "InvalidVersion", "VERSION_PATTERN"]
|
||||
|
||||
InfiniteTypes = Union[InfinityType, NegativeInfinityType]
|
||||
PrePostDevType = Union[InfiniteTypes, Tuple[str, int]]
|
||||
SubLocalType = Union[InfiniteTypes, int, str]
|
||||
LocalType = Union[
|
||||
NegativeInfinityType,
|
||||
Tuple[
|
||||
Union[
|
||||
SubLocalType,
|
||||
Tuple[SubLocalType, str],
|
||||
Tuple[NegativeInfinityType, SubLocalType],
|
||||
],
|
||||
...,
|
||||
],
|
||||
]
|
||||
CmpKey = Tuple[
|
||||
int, Tuple[int, ...], PrePostDevType, PrePostDevType, PrePostDevType, LocalType
|
||||
]
|
||||
LegacyCmpKey = Tuple[int, Tuple[str, ...]]
|
||||
VersionComparisonMethod = Callable[
|
||||
[Union[CmpKey, LegacyCmpKey], Union[CmpKey, LegacyCmpKey]], bool
|
||||
]
|
||||
|
||||
_Version = collections.namedtuple(
|
||||
"_Version", ["epoch", "release", "dev", "pre", "post", "local"]
|
||||
)
|
||||
|
||||
|
||||
def parse(version: str) -> Union["LegacyVersion", "Version"]:
|
||||
"""Parse the given version from a string to an appropriate class.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
version : str
|
||||
Version in a string format, eg. "0.9.1" or "1.2.dev0".
|
||||
|
||||
Returns
|
||||
-------
|
||||
version : :class:`Version` object or a :class:`LegacyVersion` object
|
||||
Returned class depends on the given version: if is a valid
|
||||
PEP 440 version or a legacy version.
|
||||
"""
|
||||
try:
|
||||
return Version(version)
|
||||
except InvalidVersion:
|
||||
return LegacyVersion(version)
|
||||
|
||||
|
||||
class InvalidVersion(ValueError):
|
||||
"""
|
||||
An invalid version was found, users should refer to PEP 440.
|
||||
"""
|
||||
|
||||
|
||||
class _BaseVersion:
|
||||
_key: Union[CmpKey, LegacyCmpKey]
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._key)
|
||||
|
||||
# Please keep the duplicated `isinstance` check
|
||||
# in the six comparisons hereunder
|
||||
# unless you find a way to avoid adding overhead function calls.
|
||||
def __lt__(self, other: "_BaseVersion") -> bool:
|
||||
if not isinstance(other, _BaseVersion):
|
||||
return NotImplemented
|
||||
|
||||
return self._key < other._key
|
||||
|
||||
def __le__(self, other: "_BaseVersion") -> bool:
|
||||
if not isinstance(other, _BaseVersion):
|
||||
return NotImplemented
|
||||
|
||||
return self._key <= other._key
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, _BaseVersion):
|
||||
return NotImplemented
|
||||
|
||||
return self._key == other._key
|
||||
|
||||
def __ge__(self, other: "_BaseVersion") -> bool:
|
||||
if not isinstance(other, _BaseVersion):
|
||||
return NotImplemented
|
||||
|
||||
return self._key >= other._key
|
||||
|
||||
def __gt__(self, other: "_BaseVersion") -> bool:
|
||||
if not isinstance(other, _BaseVersion):
|
||||
return NotImplemented
|
||||
|
||||
return self._key > other._key
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
if not isinstance(other, _BaseVersion):
|
||||
return NotImplemented
|
||||
|
||||
return self._key != other._key
|
||||
|
||||
|
||||
class LegacyVersion(_BaseVersion):
|
||||
def __init__(self, version: str) -> None:
|
||||
self._version = str(version)
|
||||
self._key = _legacy_cmpkey(self._version)
|
||||
|
||||
warnings.warn(
|
||||
"Creating a LegacyVersion has been deprecated and will be "
|
||||
"removed in the next major release",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._version
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<LegacyVersion('{self}')>"
|
||||
|
||||
@property
|
||||
def public(self) -> str:
|
||||
return self._version
|
||||
|
||||
@property
|
||||
def base_version(self) -> str:
|
||||
return self._version
|
||||
|
||||
@property
|
||||
def epoch(self) -> int:
|
||||
return -1
|
||||
|
||||
@property
|
||||
def release(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def pre(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def post(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def dev(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def local(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_prerelease(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_postrelease(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_devrelease(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
_legacy_version_component_re = re.compile(r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE)
|
||||
|
||||
_legacy_version_replacement_map = {
|
||||
"pre": "c",
|
||||
"preview": "c",
|
||||
"-": "final-",
|
||||
"rc": "c",
|
||||
"dev": "@",
|
||||
}
|
||||
|
||||
|
||||
def _parse_version_parts(s: str) -> Iterator[str]:
|
||||
for part in _legacy_version_component_re.split(s):
|
||||
part = _legacy_version_replacement_map.get(part, part)
|
||||
|
||||
if not part or part == ".":
|
||||
continue
|
||||
|
||||
if part[:1] in "0123456789":
|
||||
# pad for numeric comparison
|
||||
yield part.zfill(8)
|
||||
else:
|
||||
yield "*" + part
|
||||
|
||||
# ensure that alpha/beta/candidate are before final
|
||||
yield "*final"
|
||||
|
||||
|
||||
def _legacy_cmpkey(version: str) -> LegacyCmpKey:
|
||||
|
||||
# We hardcode an epoch of -1 here. A PEP 440 version can only have a epoch
|
||||
# greater than or equal to 0. This will effectively put the LegacyVersion,
|
||||
# which uses the defacto standard originally implemented by setuptools,
|
||||
# as before all PEP 440 versions.
|
||||
epoch = -1
|
||||
|
||||
# This scheme is taken from pkg_resources.parse_version setuptools prior to
|
||||
# it's adoption of the packaging library.
|
||||
parts: List[str] = []
|
||||
for part in _parse_version_parts(version.lower()):
|
||||
if part.startswith("*"):
|
||||
# remove "-" before a prerelease tag
|
||||
if part < "*final":
|
||||
while parts and parts[-1] == "*final-":
|
||||
parts.pop()
|
||||
|
||||
# remove trailing zeros from each series of numeric parts
|
||||
while parts and parts[-1] == "00000000":
|
||||
parts.pop()
|
||||
|
||||
parts.append(part)
|
||||
|
||||
return epoch, tuple(parts)
|
||||
|
||||
|
||||
# Deliberately not anchored to the start and end of the string, to make it
|
||||
# easier for 3rd party code to reuse
|
||||
VERSION_PATTERN = r"""
|
||||
v?
|
||||
(?:
|
||||
(?:(?P<epoch>[0-9]+)!)? # epoch
|
||||
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
|
||||
(?P<pre> # pre-release
|
||||
[-_\.]?
|
||||
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
|
||||
[-_\.]?
|
||||
(?P<pre_n>[0-9]+)?
|
||||
)?
|
||||
(?P<post> # post release
|
||||
(?:-(?P<post_n1>[0-9]+))
|
||||
|
|
||||
(?:
|
||||
[-_\.]?
|
||||
(?P<post_l>post|rev|r)
|
||||
[-_\.]?
|
||||
(?P<post_n2>[0-9]+)?
|
||||
)
|
||||
)?
|
||||
(?P<dev> # dev release
|
||||
[-_\.]?
|
||||
(?P<dev_l>dev)
|
||||
[-_\.]?
|
||||
(?P<dev_n>[0-9]+)?
|
||||
)?
|
||||
)
|
||||
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
|
||||
"""
|
||||
|
||||
|
||||
class Version(_BaseVersion):
|
||||
|
||||
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
|
||||
|
||||
def __init__(self, version: str) -> None:
|
||||
|
||||
# Validate the version and parse it into pieces
|
||||
match = self._regex.search(version)
|
||||
if not match:
|
||||
raise InvalidVersion(f"Invalid version: '{version}'")
|
||||
|
||||
# Store the parsed out pieces of the version
|
||||
self._version = _Version(
|
||||
epoch=int(match.group("epoch")) if match.group("epoch") else 0,
|
||||
release=tuple(int(i) for i in match.group("release").split(".")),
|
||||
pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
|
||||
post=_parse_letter_version(
|
||||
match.group("post_l"), match.group("post_n1") or match.group("post_n2")
|
||||
),
|
||||
dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
|
||||
local=_parse_local_version(match.group("local")),
|
||||
)
|
||||
|
||||
# Generate a key which will be used for sorting
|
||||
self._key = _cmpkey(
|
||||
self._version.epoch,
|
||||
self._version.release,
|
||||
self._version.pre,
|
||||
self._version.post,
|
||||
self._version.dev,
|
||||
self._version.local,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Version('{self}')>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
parts = []
|
||||
|
||||
# Epoch
|
||||
if self.epoch != 0:
|
||||
parts.append(f"{self.epoch}!")
|
||||
|
||||
# Release segment
|
||||
parts.append(".".join(str(x) for x in self.release))
|
||||
|
||||
# Pre-release
|
||||
if self.pre is not None:
|
||||
parts.append("".join(str(x) for x in self.pre))
|
||||
|
||||
# Post-release
|
||||
if self.post is not None:
|
||||
parts.append(f".post{self.post}")
|
||||
|
||||
# Development release
|
||||
if self.dev is not None:
|
||||
parts.append(f".dev{self.dev}")
|
||||
|
||||
# Local version segment
|
||||
if self.local is not None:
|
||||
parts.append(f"+{self.local}")
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@property
|
||||
def epoch(self) -> int:
|
||||
_epoch: int = self._version.epoch
|
||||
return _epoch
|
||||
|
||||
@property
|
||||
def release(self) -> Tuple[int, ...]:
|
||||
_release: Tuple[int, ...] = self._version.release
|
||||
return _release
|
||||
|
||||
@property
|
||||
def pre(self) -> Optional[Tuple[str, int]]:
|
||||
_pre: Optional[Tuple[str, int]] = self._version.pre
|
||||
return _pre
|
||||
|
||||
@property
|
||||
def post(self) -> Optional[int]:
|
||||
return self._version.post[1] if self._version.post else None
|
||||
|
||||
@property
|
||||
def dev(self) -> Optional[int]:
|
||||
return self._version.dev[1] if self._version.dev else None
|
||||
|
||||
@property
|
||||
def local(self) -> Optional[str]:
|
||||
if self._version.local:
|
||||
return ".".join(str(x) for x in self._version.local)
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def public(self) -> str:
|
||||
return str(self).split("+", 1)[0]
|
||||
|
||||
@property
|
||||
def base_version(self) -> str:
|
||||
parts = []
|
||||
|
||||
# Epoch
|
||||
if self.epoch != 0:
|
||||
parts.append(f"{self.epoch}!")
|
||||
|
||||
# Release segment
|
||||
parts.append(".".join(str(x) for x in self.release))
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@property
|
||||
def is_prerelease(self) -> bool:
|
||||
return self.dev is not None or self.pre is not None
|
||||
|
||||
@property
|
||||
def is_postrelease(self) -> bool:
|
||||
return self.post is not None
|
||||
|
||||
@property
|
||||
def is_devrelease(self) -> bool:
|
||||
return self.dev is not None
|
||||
|
||||
@property
|
||||
def major(self) -> int:
|
||||
return self.release[0] if len(self.release) >= 1 else 0
|
||||
|
||||
@property
|
||||
def minor(self) -> int:
|
||||
return self.release[1] if len(self.release) >= 2 else 0
|
||||
|
||||
@property
|
||||
def micro(self) -> int:
|
||||
return self.release[2] if len(self.release) >= 3 else 0
|
||||
|
||||
|
||||
def _parse_letter_version(
|
||||
letter: str, number: Union[str, bytes, SupportsInt]
|
||||
) -> Optional[Tuple[str, int]]:
|
||||
|
||||
if letter:
|
||||
# We consider there to be an implicit 0 in a pre-release if there is
|
||||
# not a numeral associated with it.
|
||||
if number is None:
|
||||
number = 0
|
||||
|
||||
# We normalize any letters to their lower case form
|
||||
letter = letter.lower()
|
||||
|
||||
# We consider some words to be alternate spellings of other words and
|
||||
# in those cases we want to normalize the spellings to our preferred
|
||||
# spelling.
|
||||
if letter == "alpha":
|
||||
letter = "a"
|
||||
elif letter == "beta":
|
||||
letter = "b"
|
||||
elif letter in ["c", "pre", "preview"]:
|
||||
letter = "rc"
|
||||
elif letter in ["rev", "r"]:
|
||||
letter = "post"
|
||||
|
||||
return letter, int(number)
|
||||
if not letter and number:
|
||||
# We assume if we are given a number, but we are not given a letter
|
||||
# then this is using the implicit post release syntax (e.g. 1.0-1)
|
||||
letter = "post"
|
||||
|
||||
return letter, int(number)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
_local_version_separators = re.compile(r"[\._-]")
|
||||
|
||||
|
||||
def _parse_local_version(local: str) -> Optional[LocalType]:
|
||||
"""
|
||||
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
|
||||
"""
|
||||
if local is not None:
|
||||
return tuple(
|
||||
part.lower() if not part.isdigit() else int(part)
|
||||
for part in _local_version_separators.split(local)
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _cmpkey(
|
||||
epoch: int,
|
||||
release: Tuple[int, ...],
|
||||
pre: Optional[Tuple[str, int]],
|
||||
post: Optional[Tuple[str, int]],
|
||||
dev: Optional[Tuple[str, int]],
|
||||
local: Optional[Tuple[SubLocalType]],
|
||||
) -> CmpKey:
|
||||
|
||||
# When we compare a release version, we want to compare it with all of the
|
||||
# trailing zeros removed. So we'll use a reverse the list, drop all the now
|
||||
# leading zeros until we come to something non zero, then take the rest
|
||||
# re-reverse it back into the correct order and make it a tuple and use
|
||||
# that for our sorting key.
|
||||
_release = tuple(
|
||||
reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
|
||||
)
|
||||
|
||||
# We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
|
||||
# We'll do this by abusing the pre segment, but we _only_ want to do this
|
||||
# if there is not a pre or a post segment. If we have one of those then
|
||||
# the normal sorting rules will handle this case correctly.
|
||||
if pre is None and post is None and dev is not None:
|
||||
_pre: PrePostDevType = NegativeInfinity
|
||||
# Versions without a pre-release (except as noted above) should sort after
|
||||
# those with one.
|
||||
elif pre is None:
|
||||
_pre = Infinity
|
||||
else:
|
||||
_pre = pre
|
||||
|
||||
# Versions without a post segment should sort before those with one.
|
||||
if post is None:
|
||||
_post: PrePostDevType = NegativeInfinity
|
||||
|
||||
else:
|
||||
_post = post
|
||||
|
||||
# Versions without a development segment should sort after those with one.
|
||||
if dev is None:
|
||||
_dev: PrePostDevType = Infinity
|
||||
|
||||
else:
|
||||
_dev = dev
|
||||
|
||||
if local is None:
|
||||
# Versions without a local segment should sort before those with one.
|
||||
_local: LocalType = NegativeInfinity
|
||||
else:
|
||||
# Versions with a local segment need that segment parsed to implement
|
||||
# the sorting rules in PEP440.
|
||||
# - Alpha numeric segments sort before numeric segments
|
||||
# - Alpha numeric segments sort lexicographically
|
||||
# - Numeric segments sort numerically
|
||||
# - Shorter versions sort before longer versions when the prefixes
|
||||
# match exactly
|
||||
_local = tuple(
|
||||
(i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
|
||||
)
|
||||
|
||||
return epoch, _release, _pre, _post, _dev, _local
|
||||
0
venv/lib/python3.12/site-packages/sklearn/externals/_scipy/__init__.py
vendored
Normal file
0
venv/lib/python3.12/site-packages/sklearn/externals/_scipy/__init__.py
vendored
Normal file
BIN
venv/lib/python3.12/site-packages/sklearn/externals/_scipy/__pycache__/__init__.cpython-312.pyc
vendored
Normal file
BIN
venv/lib/python3.12/site-packages/sklearn/externals/_scipy/__pycache__/__init__.cpython-312.pyc
vendored
Normal file
Binary file not shown.
0
venv/lib/python3.12/site-packages/sklearn/externals/_scipy/sparse/__init__.py
vendored
Normal file
0
venv/lib/python3.12/site-packages/sklearn/externals/_scipy/sparse/__init__.py
vendored
Normal file
Binary file not shown.
1
venv/lib/python3.12/site-packages/sklearn/externals/_scipy/sparse/csgraph/__init__.py
vendored
Normal file
1
venv/lib/python3.12/site-packages/sklearn/externals/_scipy/sparse/csgraph/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
from ._laplacian import laplacian
|
||||
Binary file not shown.
Binary file not shown.
557
venv/lib/python3.12/site-packages/sklearn/externals/_scipy/sparse/csgraph/_laplacian.py
vendored
Normal file
557
venv/lib/python3.12/site-packages/sklearn/externals/_scipy/sparse/csgraph/_laplacian.py
vendored
Normal file
@@ -0,0 +1,557 @@
|
||||
"""
|
||||
This file is a copy of the scipy.sparse.csgraph._laplacian module from SciPy 1.12
|
||||
|
||||
scipy.sparse.csgraph.laplacian supports sparse arrays only starting from Scipy 1.12,
|
||||
see https://github.com/scipy/scipy/pull/19156. This vendored file can be removed as
|
||||
soon as Scipy 1.12 becomes the minimum supported version.
|
||||
|
||||
Laplacian of a compressed-sparse graph
|
||||
"""
|
||||
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
from scipy.sparse import issparse
|
||||
from scipy.sparse.linalg import LinearOperator
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Graph laplacian
|
||||
def laplacian(
|
||||
csgraph,
|
||||
normed=False,
|
||||
return_diag=False,
|
||||
use_out_degree=False,
|
||||
*,
|
||||
copy=True,
|
||||
form="array",
|
||||
dtype=None,
|
||||
symmetrized=False,
|
||||
):
|
||||
"""
|
||||
Return the Laplacian of a directed graph.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
csgraph : array_like or sparse matrix, 2 dimensions
|
||||
Compressed-sparse graph, with shape (N, N).
|
||||
normed : bool, optional
|
||||
If True, then compute symmetrically normalized Laplacian.
|
||||
Default: False.
|
||||
return_diag : bool, optional
|
||||
If True, then also return an array related to vertex degrees.
|
||||
Default: False.
|
||||
use_out_degree : bool, optional
|
||||
If True, then use out-degree instead of in-degree.
|
||||
This distinction matters only if the graph is asymmetric.
|
||||
Default: False.
|
||||
copy : bool, optional
|
||||
If False, then change `csgraph` in place if possible,
|
||||
avoiding doubling the memory use.
|
||||
Default: True, for backward compatibility.
|
||||
form : 'array', or 'function', or 'lo'
|
||||
Determines the format of the output Laplacian:
|
||||
|
||||
* 'array' is a numpy array;
|
||||
* 'function' is a pointer to evaluating the Laplacian-vector
|
||||
or Laplacian-matrix product;
|
||||
* 'lo' results in the format of the `LinearOperator`.
|
||||
|
||||
Choosing 'function' or 'lo' always avoids doubling
|
||||
the memory use, ignoring `copy` value.
|
||||
Default: 'array', for backward compatibility.
|
||||
dtype : None or one of numeric numpy dtypes, optional
|
||||
The dtype of the output. If ``dtype=None``, the dtype of the
|
||||
output matches the dtype of the input csgraph, except for
|
||||
the case ``normed=True`` and integer-like csgraph, where
|
||||
the output dtype is 'float' allowing accurate normalization,
|
||||
but dramatically increasing the memory use.
|
||||
Default: None, for backward compatibility.
|
||||
symmetrized : bool, optional
|
||||
If True, then the output Laplacian is symmetric/Hermitian.
|
||||
The symmetrization is done by ``csgraph + csgraph.T.conj``
|
||||
without dividing by 2 to preserve integer dtypes if possible
|
||||
prior to the construction of the Laplacian.
|
||||
The symmetrization will increase the memory footprint of
|
||||
sparse matrices unless the sparsity pattern is symmetric or
|
||||
`form` is 'function' or 'lo'.
|
||||
Default: False, for backward compatibility.
|
||||
|
||||
Returns
|
||||
-------
|
||||
lap : ndarray, or sparse matrix, or `LinearOperator`
|
||||
The N x N Laplacian of csgraph. It will be a NumPy array (dense)
|
||||
if the input was dense, or a sparse matrix otherwise, or
|
||||
the format of a function or `LinearOperator` if
|
||||
`form` equals 'function' or 'lo', respectively.
|
||||
diag : ndarray, optional
|
||||
The length-N main diagonal of the Laplacian matrix.
|
||||
For the normalized Laplacian, this is the array of square roots
|
||||
of vertex degrees or 1 if the degree is zero.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The Laplacian matrix of a graph is sometimes referred to as the
|
||||
"Kirchhoff matrix" or just the "Laplacian", and is useful in many
|
||||
parts of spectral graph theory.
|
||||
In particular, the eigen-decomposition of the Laplacian can give
|
||||
insight into many properties of the graph, e.g.,
|
||||
is commonly used for spectral data embedding and clustering.
|
||||
|
||||
The constructed Laplacian doubles the memory use if ``copy=True`` and
|
||||
``form="array"`` which is the default.
|
||||
Choosing ``copy=False`` has no effect unless ``form="array"``
|
||||
or the matrix is sparse in the ``coo`` format, or dense array, except
|
||||
for the integer input with ``normed=True`` that forces the float output.
|
||||
|
||||
Sparse input is reformatted into ``coo`` if ``form="array"``,
|
||||
which is the default.
|
||||
|
||||
If the input adjacency matrix is not symmetric, the Laplacian is
|
||||
also non-symmetric unless ``symmetrized=True`` is used.
|
||||
|
||||
Diagonal entries of the input adjacency matrix are ignored and
|
||||
replaced with zeros for the purpose of normalization where ``normed=True``.
|
||||
The normalization uses the inverse square roots of row-sums of the input
|
||||
adjacency matrix, and thus may fail if the row-sums contain
|
||||
negative or complex with a non-zero imaginary part values.
|
||||
|
||||
The normalization is symmetric, making the normalized Laplacian also
|
||||
symmetric if the input csgraph was symmetric.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Laplacian matrix. https://en.wikipedia.org/wiki/Laplacian_matrix
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> from scipy.sparse import csgraph
|
||||
|
||||
Our first illustration is the symmetric graph
|
||||
|
||||
>>> G = np.arange(4) * np.arange(4)[:, np.newaxis]
|
||||
>>> G
|
||||
array([[0, 0, 0, 0],
|
||||
[0, 1, 2, 3],
|
||||
[0, 2, 4, 6],
|
||||
[0, 3, 6, 9]])
|
||||
|
||||
and its symmetric Laplacian matrix
|
||||
|
||||
>>> csgraph.laplacian(G)
|
||||
array([[ 0, 0, 0, 0],
|
||||
[ 0, 5, -2, -3],
|
||||
[ 0, -2, 8, -6],
|
||||
[ 0, -3, -6, 9]])
|
||||
|
||||
The non-symmetric graph
|
||||
|
||||
>>> G = np.arange(9).reshape(3, 3)
|
||||
>>> G
|
||||
array([[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8]])
|
||||
|
||||
has different row- and column sums, resulting in two varieties
|
||||
of the Laplacian matrix, using an in-degree, which is the default
|
||||
|
||||
>>> L_in_degree = csgraph.laplacian(G)
|
||||
>>> L_in_degree
|
||||
array([[ 9, -1, -2],
|
||||
[-3, 8, -5],
|
||||
[-6, -7, 7]])
|
||||
|
||||
or alternatively an out-degree
|
||||
|
||||
>>> L_out_degree = csgraph.laplacian(G, use_out_degree=True)
|
||||
>>> L_out_degree
|
||||
array([[ 3, -1, -2],
|
||||
[-3, 8, -5],
|
||||
[-6, -7, 13]])
|
||||
|
||||
Constructing a symmetric Laplacian matrix, one can add the two as
|
||||
|
||||
>>> L_in_degree + L_out_degree.T
|
||||
array([[ 12, -4, -8],
|
||||
[ -4, 16, -12],
|
||||
[ -8, -12, 20]])
|
||||
|
||||
or use the ``symmetrized=True`` option
|
||||
|
||||
>>> csgraph.laplacian(G, symmetrized=True)
|
||||
array([[ 12, -4, -8],
|
||||
[ -4, 16, -12],
|
||||
[ -8, -12, 20]])
|
||||
|
||||
that is equivalent to symmetrizing the original graph
|
||||
|
||||
>>> csgraph.laplacian(G + G.T)
|
||||
array([[ 12, -4, -8],
|
||||
[ -4, 16, -12],
|
||||
[ -8, -12, 20]])
|
||||
|
||||
The goal of normalization is to make the non-zero diagonal entries
|
||||
of the Laplacian matrix to be all unit, also scaling off-diagonal
|
||||
entries correspondingly. The normalization can be done manually, e.g.,
|
||||
|
||||
>>> G = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
|
||||
>>> L, d = csgraph.laplacian(G, return_diag=True)
|
||||
>>> L
|
||||
array([[ 2, -1, -1],
|
||||
[-1, 2, -1],
|
||||
[-1, -1, 2]])
|
||||
>>> d
|
||||
array([2, 2, 2])
|
||||
>>> scaling = np.sqrt(d)
|
||||
>>> scaling
|
||||
array([1.41421356, 1.41421356, 1.41421356])
|
||||
>>> (1/scaling)*L*(1/scaling)
|
||||
array([[ 1. , -0.5, -0.5],
|
||||
[-0.5, 1. , -0.5],
|
||||
[-0.5, -0.5, 1. ]])
|
||||
|
||||
Or using ``normed=True`` option
|
||||
|
||||
>>> L, d = csgraph.laplacian(G, return_diag=True, normed=True)
|
||||
>>> L
|
||||
array([[ 1. , -0.5, -0.5],
|
||||
[-0.5, 1. , -0.5],
|
||||
[-0.5, -0.5, 1. ]])
|
||||
|
||||
which now instead of the diagonal returns the scaling coefficients
|
||||
|
||||
>>> d
|
||||
array([1.41421356, 1.41421356, 1.41421356])
|
||||
|
||||
Zero scaling coefficients are substituted with 1s, where scaling
|
||||
has thus no effect, e.g.,
|
||||
|
||||
>>> G = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0]])
|
||||
>>> G
|
||||
array([[0, 0, 0],
|
||||
[0, 0, 1],
|
||||
[0, 1, 0]])
|
||||
>>> L, d = csgraph.laplacian(G, return_diag=True, normed=True)
|
||||
>>> L
|
||||
array([[ 0., -0., -0.],
|
||||
[-0., 1., -1.],
|
||||
[-0., -1., 1.]])
|
||||
>>> d
|
||||
array([1., 1., 1.])
|
||||
|
||||
Only the symmetric normalization is implemented, resulting
|
||||
in a symmetric Laplacian matrix if and only if its graph is symmetric
|
||||
and has all non-negative degrees, like in the examples above.
|
||||
|
||||
The output Laplacian matrix is by default a dense array or a sparse matrix
|
||||
inferring its shape, format, and dtype from the input graph matrix:
|
||||
|
||||
>>> G = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]).astype(np.float32)
|
||||
>>> G
|
||||
array([[0., 1., 1.],
|
||||
[1., 0., 1.],
|
||||
[1., 1., 0.]], dtype=float32)
|
||||
>>> csgraph.laplacian(G)
|
||||
array([[ 2., -1., -1.],
|
||||
[-1., 2., -1.],
|
||||
[-1., -1., 2.]], dtype=float32)
|
||||
|
||||
but can alternatively be generated matrix-free as a LinearOperator:
|
||||
|
||||
>>> L = csgraph.laplacian(G, form="lo")
|
||||
>>> L
|
||||
<3x3 _CustomLinearOperator with dtype=float32>
|
||||
>>> L(np.eye(3))
|
||||
array([[ 2., -1., -1.],
|
||||
[-1., 2., -1.],
|
||||
[-1., -1., 2.]])
|
||||
|
||||
or as a lambda-function:
|
||||
|
||||
>>> L = csgraph.laplacian(G, form="function")
|
||||
>>> L
|
||||
<function _laplace.<locals>.<lambda> at 0x0000012AE6F5A598>
|
||||
>>> L(np.eye(3))
|
||||
array([[ 2., -1., -1.],
|
||||
[-1., 2., -1.],
|
||||
[-1., -1., 2.]])
|
||||
|
||||
The Laplacian matrix is used for
|
||||
spectral data clustering and embedding
|
||||
as well as for spectral graph partitioning.
|
||||
Our final example illustrates the latter
|
||||
for a noisy directed linear graph.
|
||||
|
||||
>>> from scipy.sparse import diags, random
|
||||
>>> from scipy.sparse.linalg import lobpcg
|
||||
|
||||
Create a directed linear graph with ``N=35`` vertices
|
||||
using a sparse adjacency matrix ``G``:
|
||||
|
||||
>>> N = 35
|
||||
>>> G = diags(np.ones(N-1), 1, format="csr")
|
||||
|
||||
Fix a random seed ``rng`` and add a random sparse noise to the graph ``G``:
|
||||
|
||||
>>> rng = np.random.default_rng()
|
||||
>>> G += 1e-2 * random(N, N, density=0.1, random_state=rng)
|
||||
|
||||
Set initial approximations for eigenvectors:
|
||||
|
||||
>>> X = rng.random((N, 2))
|
||||
|
||||
The constant vector of ones is always a trivial eigenvector
|
||||
of the non-normalized Laplacian to be filtered out:
|
||||
|
||||
>>> Y = np.ones((N, 1))
|
||||
|
||||
Alternating (1) the sign of the graph weights allows determining
|
||||
labels for spectral max- and min- cuts in a single loop.
|
||||
Since the graph is undirected, the option ``symmetrized=True``
|
||||
must be used in the construction of the Laplacian.
|
||||
The option ``normed=True`` cannot be used in (2) for the negative weights
|
||||
here as the symmetric normalization evaluates square roots.
|
||||
The option ``form="lo"`` in (2) is matrix-free, i.e., guarantees
|
||||
a fixed memory footprint and read-only access to the graph.
|
||||
Calling the eigenvalue solver ``lobpcg`` (3) computes the Fiedler vector
|
||||
that determines the labels as the signs of its components in (5).
|
||||
Since the sign in an eigenvector is not deterministic and can flip,
|
||||
we fix the sign of the first component to be always +1 in (4).
|
||||
|
||||
>>> for cut in ["max", "min"]:
|
||||
... G = -G # 1.
|
||||
... L = csgraph.laplacian(G, symmetrized=True, form="lo") # 2.
|
||||
... _, eves = lobpcg(L, X, Y=Y, largest=False, tol=1e-3) # 3.
|
||||
... eves *= np.sign(eves[0, 0]) # 4.
|
||||
... print(cut + "-cut labels:\\n", 1 * (eves[:, 0]>0)) # 5.
|
||||
max-cut labels:
|
||||
[1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1]
|
||||
min-cut labels:
|
||||
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
|
||||
As anticipated for a (slightly noisy) linear graph,
|
||||
the max-cut strips all the edges of the graph coloring all
|
||||
odd vertices into one color and all even vertices into another one,
|
||||
while the balanced min-cut partitions the graph
|
||||
in the middle by deleting a single edge.
|
||||
Both determined partitions are optimal.
|
||||
"""
|
||||
if csgraph.ndim != 2 or csgraph.shape[0] != csgraph.shape[1]:
|
||||
raise ValueError("csgraph must be a square matrix or array")
|
||||
|
||||
if normed and (
|
||||
np.issubdtype(csgraph.dtype, np.signedinteger)
|
||||
or np.issubdtype(csgraph.dtype, np.uint)
|
||||
):
|
||||
csgraph = csgraph.astype(np.float64)
|
||||
|
||||
if form == "array":
|
||||
create_lap = _laplacian_sparse if issparse(csgraph) else _laplacian_dense
|
||||
else:
|
||||
create_lap = (
|
||||
_laplacian_sparse_flo if issparse(csgraph) else _laplacian_dense_flo
|
||||
)
|
||||
|
||||
degree_axis = 1 if use_out_degree else 0
|
||||
|
||||
lap, d = create_lap(
|
||||
csgraph,
|
||||
normed=normed,
|
||||
axis=degree_axis,
|
||||
copy=copy,
|
||||
form=form,
|
||||
dtype=dtype,
|
||||
symmetrized=symmetrized,
|
||||
)
|
||||
if return_diag:
|
||||
return lap, d
|
||||
return lap
|
||||
|
||||
|
||||
def _setdiag_dense(m, d):
|
||||
step = len(d) + 1
|
||||
m.flat[::step] = d
|
||||
|
||||
|
||||
def _laplace(m, d):
|
||||
return lambda v: v * d[:, np.newaxis] - m @ v
|
||||
|
||||
|
||||
def _laplace_normed(m, d, nd):
|
||||
laplace = _laplace(m, d)
|
||||
return lambda v: nd[:, np.newaxis] * laplace(v * nd[:, np.newaxis])
|
||||
|
||||
|
||||
def _laplace_sym(m, d):
|
||||
return (
|
||||
lambda v: v * d[:, np.newaxis]
|
||||
- m @ v
|
||||
- np.transpose(np.conjugate(np.transpose(np.conjugate(v)) @ m))
|
||||
)
|
||||
|
||||
|
||||
def _laplace_normed_sym(m, d, nd):
|
||||
laplace_sym = _laplace_sym(m, d)
|
||||
return lambda v: nd[:, np.newaxis] * laplace_sym(v * nd[:, np.newaxis])
|
||||
|
||||
|
||||
def _linearoperator(mv, shape, dtype):
|
||||
return LinearOperator(matvec=mv, matmat=mv, shape=shape, dtype=dtype)
|
||||
|
||||
|
||||
def _laplacian_sparse_flo(graph, normed, axis, copy, form, dtype, symmetrized):
|
||||
# The keyword argument `copy` is unused and has no effect here.
|
||||
del copy
|
||||
|
||||
if dtype is None:
|
||||
dtype = graph.dtype
|
||||
|
||||
graph_sum = np.asarray(graph.sum(axis=axis)).ravel()
|
||||
graph_diagonal = graph.diagonal()
|
||||
diag = graph_sum - graph_diagonal
|
||||
if symmetrized:
|
||||
graph_sum += np.asarray(graph.sum(axis=1 - axis)).ravel()
|
||||
diag = graph_sum - graph_diagonal - graph_diagonal
|
||||
|
||||
if normed:
|
||||
isolated_node_mask = diag == 0
|
||||
w = np.where(isolated_node_mask, 1, np.sqrt(diag))
|
||||
if symmetrized:
|
||||
md = _laplace_normed_sym(graph, graph_sum, 1.0 / w)
|
||||
else:
|
||||
md = _laplace_normed(graph, graph_sum, 1.0 / w)
|
||||
if form == "function":
|
||||
return md, w.astype(dtype, copy=False)
|
||||
elif form == "lo":
|
||||
m = _linearoperator(md, shape=graph.shape, dtype=dtype)
|
||||
return m, w.astype(dtype, copy=False)
|
||||
else:
|
||||
raise ValueError(f"Invalid form: {form!r}")
|
||||
else:
|
||||
if symmetrized:
|
||||
md = _laplace_sym(graph, graph_sum)
|
||||
else:
|
||||
md = _laplace(graph, graph_sum)
|
||||
if form == "function":
|
||||
return md, diag.astype(dtype, copy=False)
|
||||
elif form == "lo":
|
||||
m = _linearoperator(md, shape=graph.shape, dtype=dtype)
|
||||
return m, diag.astype(dtype, copy=False)
|
||||
else:
|
||||
raise ValueError(f"Invalid form: {form!r}")
|
||||
|
||||
|
||||
def _laplacian_sparse(graph, normed, axis, copy, form, dtype, symmetrized):
|
||||
# The keyword argument `form` is unused and has no effect here.
|
||||
del form
|
||||
|
||||
if dtype is None:
|
||||
dtype = graph.dtype
|
||||
|
||||
needs_copy = False
|
||||
if graph.format in ("lil", "dok"):
|
||||
m = graph.tocoo()
|
||||
else:
|
||||
m = graph
|
||||
if copy:
|
||||
needs_copy = True
|
||||
|
||||
if symmetrized:
|
||||
m += m.T.conj()
|
||||
|
||||
w = np.asarray(m.sum(axis=axis)).ravel() - m.diagonal()
|
||||
if normed:
|
||||
m = m.tocoo(copy=needs_copy)
|
||||
isolated_node_mask = w == 0
|
||||
w = np.where(isolated_node_mask, 1, np.sqrt(w))
|
||||
m.data /= w[m.row]
|
||||
m.data /= w[m.col]
|
||||
m.data *= -1
|
||||
m.setdiag(1 - isolated_node_mask)
|
||||
else:
|
||||
if m.format == "dia":
|
||||
m = m.copy()
|
||||
else:
|
||||
m = m.tocoo(copy=needs_copy)
|
||||
m.data *= -1
|
||||
m.setdiag(w)
|
||||
|
||||
return m.astype(dtype, copy=False), w.astype(dtype)
|
||||
|
||||
|
||||
def _laplacian_dense_flo(graph, normed, axis, copy, form, dtype, symmetrized):
|
||||
if copy:
|
||||
m = np.array(graph)
|
||||
else:
|
||||
m = np.asarray(graph)
|
||||
|
||||
if dtype is None:
|
||||
dtype = m.dtype
|
||||
|
||||
graph_sum = m.sum(axis=axis)
|
||||
graph_diagonal = m.diagonal()
|
||||
diag = graph_sum - graph_diagonal
|
||||
if symmetrized:
|
||||
graph_sum += m.sum(axis=1 - axis)
|
||||
diag = graph_sum - graph_diagonal - graph_diagonal
|
||||
|
||||
if normed:
|
||||
isolated_node_mask = diag == 0
|
||||
w = np.where(isolated_node_mask, 1, np.sqrt(diag))
|
||||
if symmetrized:
|
||||
md = _laplace_normed_sym(m, graph_sum, 1.0 / w)
|
||||
else:
|
||||
md = _laplace_normed(m, graph_sum, 1.0 / w)
|
||||
if form == "function":
|
||||
return md, w.astype(dtype, copy=False)
|
||||
elif form == "lo":
|
||||
m = _linearoperator(md, shape=graph.shape, dtype=dtype)
|
||||
return m, w.astype(dtype, copy=False)
|
||||
else:
|
||||
raise ValueError(f"Invalid form: {form!r}")
|
||||
else:
|
||||
if symmetrized:
|
||||
md = _laplace_sym(m, graph_sum)
|
||||
else:
|
||||
md = _laplace(m, graph_sum)
|
||||
if form == "function":
|
||||
return md, diag.astype(dtype, copy=False)
|
||||
elif form == "lo":
|
||||
m = _linearoperator(md, shape=graph.shape, dtype=dtype)
|
||||
return m, diag.astype(dtype, copy=False)
|
||||
else:
|
||||
raise ValueError(f"Invalid form: {form!r}")
|
||||
|
||||
|
||||
def _laplacian_dense(graph, normed, axis, copy, form, dtype, symmetrized):
|
||||
if form != "array":
|
||||
raise ValueError(f'{form!r} must be "array"')
|
||||
|
||||
if dtype is None:
|
||||
dtype = graph.dtype
|
||||
|
||||
if copy:
|
||||
m = np.array(graph)
|
||||
else:
|
||||
m = np.asarray(graph)
|
||||
|
||||
if dtype is None:
|
||||
dtype = m.dtype
|
||||
|
||||
if symmetrized:
|
||||
m += m.T.conj()
|
||||
np.fill_diagonal(m, 0)
|
||||
w = m.sum(axis=axis)
|
||||
if normed:
|
||||
isolated_node_mask = w == 0
|
||||
w = np.where(isolated_node_mask, 1, np.sqrt(w))
|
||||
m /= w
|
||||
m /= w[:, np.newaxis]
|
||||
m *= -1
|
||||
_setdiag_dense(m, 1 - isolated_node_mask)
|
||||
else:
|
||||
m *= -1
|
||||
_setdiag_dense(m, w)
|
||||
|
||||
return m.astype(dtype, copy=False), w.astype(dtype, copy=False)
|
||||
21
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/LICENSE
vendored
Normal file
21
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/LICENSE
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Consortium for Python Data API Standards
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
1
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/README.md
vendored
Normal file
1
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/README.md
vendored
Normal file
@@ -0,0 +1 @@
|
||||
Update this directory using maint_tools/vendor_array_api_compat.sh
|
||||
22
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/__init__.py
vendored
Normal file
22
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/__init__.py
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
NumPy Array API compatibility library
|
||||
|
||||
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that are
|
||||
compatible with the Array API standard https://data-apis.org/array-api/latest/.
|
||||
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
|
||||
|
||||
Unlike array_api_strict, this is not a strict minimal implementation of the
|
||||
Array API, but rather just an extension of the main NumPy namespace with
|
||||
changes needed to be compliant with the Array API. See
|
||||
https://numpy.org/doc/stable/reference/array_api.html for a full list of
|
||||
changes. In particular, unlike array_api_strict, this package does not use a
|
||||
separate Array object, but rather just uses numpy.ndarray directly.
|
||||
|
||||
Library authors using the Array API may wish to test against array_api_strict
|
||||
to ensure they are not using functionality outside of the standard, but prefer
|
||||
this implementation for the default when working with NumPy arrays.
|
||||
|
||||
"""
|
||||
__version__ = '1.12.0'
|
||||
|
||||
from .common import * # noqa: F401, F403
|
||||
Binary file not shown.
Binary file not shown.
59
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/_internal.py
vendored
Normal file
59
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/_internal.py
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Internal helpers
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from inspect import signature
|
||||
from types import ModuleType
|
||||
from typing import TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
|
||||
"""
|
||||
Decorator to automatically replace xp with the corresponding array module.
|
||||
|
||||
Use like
|
||||
|
||||
import numpy as np
|
||||
|
||||
@get_xp(np)
|
||||
def func(x, /, xp, kwarg=None):
|
||||
return xp.func(x, kwarg=kwarg)
|
||||
|
||||
Note that xp must be a keyword argument and come after all non-keyword
|
||||
arguments.
|
||||
|
||||
"""
|
||||
|
||||
def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
|
||||
@wraps(f)
|
||||
def wrapped_f(*args: object, **kwargs: object) -> object:
|
||||
return f(*args, xp=xp, **kwargs)
|
||||
|
||||
sig = signature(f)
|
||||
new_sig = sig.replace(
|
||||
parameters=[par for i, par in sig.parameters.items() if i != "xp"]
|
||||
)
|
||||
|
||||
if wrapped_f.__doc__ is None:
|
||||
wrapped_f.__doc__ = f"""\
|
||||
Array API compatibility wrapper for {f.__name__}.
|
||||
|
||||
See the corresponding documentation in NumPy/CuPy and/or the array API
|
||||
specification for more details.
|
||||
|
||||
"""
|
||||
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
|
||||
return wrapped_f # pyright: ignore[reportReturnType]
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
__all__ = ["get_xp"]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
1
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/__init__.py
vendored
Normal file
1
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
from ._helpers import * # noqa: F403
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
727
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/_aliases.py
vendored
Normal file
727
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/_aliases.py
vendored
Normal file
@@ -0,0 +1,727 @@
|
||||
"""
|
||||
These are functions that are just aliases of existing functions in NumPy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
|
||||
|
||||
from ._helpers import _check_device, array_namespace
|
||||
from ._helpers import device as _get_device
|
||||
from ._helpers import is_cupy_namespace as _is_cupy_namespace
|
||||
from ._typing import Array, Device, DType, Namespace
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# TODO: import from typing (requires Python >=3.13)
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
# These functions are modified from the NumPy versions.
|
||||
|
||||
# Creation functions add the device keyword (which does nothing for NumPy and Dask)
|
||||
|
||||
|
||||
def arange(
|
||||
start: float,
|
||||
/,
|
||||
stop: float | None = None,
|
||||
step: float = 1,
|
||||
*,
|
||||
xp: Namespace,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def empty(
|
||||
shape: int | tuple[int, ...],
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.empty(shape, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def empty_like(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.empty_like(x, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def eye(
|
||||
n_rows: int,
|
||||
n_cols: int | None = None,
|
||||
/,
|
||||
*,
|
||||
xp: Namespace,
|
||||
k: int = 0,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def full(
|
||||
shape: int | tuple[int, ...],
|
||||
fill_value: complex,
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def full_like(
|
||||
x: Array,
|
||||
/,
|
||||
fill_value: complex,
|
||||
*,
|
||||
xp: Namespace,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def linspace(
|
||||
start: float,
|
||||
stop: float,
|
||||
/,
|
||||
num: int,
|
||||
*,
|
||||
xp: Namespace,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
endpoint: bool = True,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
|
||||
|
||||
|
||||
def ones(
|
||||
shape: int | tuple[int, ...],
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.ones(shape, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def ones_like(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.ones_like(x, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def zeros(
|
||||
shape: int | tuple[int, ...],
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.zeros(shape, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def zeros_like(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.zeros_like(x, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
# np.unique() is split into four functions in the array API:
|
||||
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
|
||||
# to remove polymorphic return types).
|
||||
|
||||
# The functions here return namedtuples (np.unique() returns a normal
|
||||
# tuple).
|
||||
|
||||
|
||||
# Note that these named tuples aren't actually part of the standard namespace,
|
||||
# but I don't see any issue with exporting the names here regardless.
|
||||
class UniqueAllResult(NamedTuple):
|
||||
values: Array
|
||||
indices: Array
|
||||
inverse_indices: Array
|
||||
counts: Array
|
||||
|
||||
|
||||
class UniqueCountsResult(NamedTuple):
|
||||
values: Array
|
||||
counts: Array
|
||||
|
||||
|
||||
class UniqueInverseResult(NamedTuple):
|
||||
values: Array
|
||||
inverse_indices: Array
|
||||
|
||||
|
||||
def _unique_kwargs(xp: Namespace) -> dict[str, bool]:
|
||||
# Older versions of NumPy and CuPy do not have equal_nan. Rather than
|
||||
# trying to parse version numbers, just check if equal_nan is in the
|
||||
# signature.
|
||||
s = inspect.signature(xp.unique)
|
||||
if "equal_nan" in s.parameters:
|
||||
return {"equal_nan": False}
|
||||
return {}
|
||||
|
||||
|
||||
def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult:
|
||||
kwargs = _unique_kwargs(xp)
|
||||
values, indices, inverse_indices, counts = xp.unique(
|
||||
x,
|
||||
return_counts=True,
|
||||
return_index=True,
|
||||
return_inverse=True,
|
||||
**kwargs,
|
||||
)
|
||||
# np.unique() flattens inverse indices, but they need to share x's shape
|
||||
# See https://github.com/numpy/numpy/issues/20638
|
||||
inverse_indices = inverse_indices.reshape(x.shape)
|
||||
return UniqueAllResult(
|
||||
values,
|
||||
indices,
|
||||
inverse_indices,
|
||||
counts,
|
||||
)
|
||||
|
||||
|
||||
def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult:
|
||||
kwargs = _unique_kwargs(xp)
|
||||
res = xp.unique(
|
||||
x, return_counts=True, return_index=False, return_inverse=False, **kwargs
|
||||
)
|
||||
|
||||
return UniqueCountsResult(*res)
|
||||
|
||||
|
||||
def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult:
|
||||
kwargs = _unique_kwargs(xp)
|
||||
values, inverse_indices = xp.unique(
|
||||
x,
|
||||
return_counts=False,
|
||||
return_index=False,
|
||||
return_inverse=True,
|
||||
**kwargs,
|
||||
)
|
||||
# xp.unique() flattens inverse indices, but they need to share x's shape
|
||||
# See https://github.com/numpy/numpy/issues/20638
|
||||
inverse_indices = inverse_indices.reshape(x.shape)
|
||||
return UniqueInverseResult(values, inverse_indices)
|
||||
|
||||
|
||||
def unique_values(x: Array, /, xp: Namespace) -> Array:
|
||||
kwargs = _unique_kwargs(xp)
|
||||
return xp.unique(
|
||||
x,
|
||||
return_counts=False,
|
||||
return_index=False,
|
||||
return_inverse=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# These functions have different keyword argument names
|
||||
|
||||
|
||||
def std(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int | tuple[int, ...] | None = None,
|
||||
correction: float = 0.0, # correction instead of ddof
|
||||
keepdims: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
|
||||
|
||||
|
||||
def var(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int | tuple[int, ...] | None = None,
|
||||
correction: float = 0.0, # correction instead of ddof
|
||||
keepdims: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
|
||||
|
||||
|
||||
# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
|
||||
# argument
|
||||
|
||||
|
||||
def cumulative_sum(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int | None = None,
|
||||
dtype: DType | None = None,
|
||||
include_initial: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
wrapped_xp = array_namespace(x)
|
||||
|
||||
# TODO: The standard is not clear about what should happen when x.ndim == 0.
|
||||
if axis is None:
|
||||
if x.ndim > 1:
|
||||
raise ValueError(
|
||||
"axis must be specified in cumulative_sum for more than one dimension"
|
||||
)
|
||||
axis = 0
|
||||
|
||||
res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
|
||||
|
||||
# np.cumsum does not support include_initial
|
||||
if include_initial:
|
||||
initial_shape = list(x.shape)
|
||||
initial_shape[axis] = 1
|
||||
res = xp.concatenate(
|
||||
[
|
||||
wrapped_xp.zeros(
|
||||
shape=initial_shape, dtype=res.dtype, device=_get_device(res)
|
||||
),
|
||||
res,
|
||||
],
|
||||
axis=axis,
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
def cumulative_prod(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int | None = None,
|
||||
dtype: DType | None = None,
|
||||
include_initial: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
wrapped_xp = array_namespace(x)
|
||||
|
||||
if axis is None:
|
||||
if x.ndim > 1:
|
||||
raise ValueError(
|
||||
"axis must be specified in cumulative_prod for more than one dimension"
|
||||
)
|
||||
axis = 0
|
||||
|
||||
res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
|
||||
|
||||
# np.cumprod does not support include_initial
|
||||
if include_initial:
|
||||
initial_shape = list(x.shape)
|
||||
initial_shape[axis] = 1
|
||||
res = xp.concatenate(
|
||||
[
|
||||
wrapped_xp.ones(
|
||||
shape=initial_shape, dtype=res.dtype, device=_get_device(res)
|
||||
),
|
||||
res,
|
||||
],
|
||||
axis=axis,
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
# The min and max argument names in clip are different and not optional in numpy, and type
|
||||
# promotion behavior is different.
|
||||
def clip(
|
||||
x: Array,
|
||||
/,
|
||||
min: float | Array | None = None,
|
||||
max: float | Array | None = None,
|
||||
*,
|
||||
xp: Namespace,
|
||||
# TODO: np.clip has other ufunc kwargs
|
||||
out: Array | None = None,
|
||||
) -> Array:
|
||||
def _isscalar(a: object) -> TypeIs[int | float | None]:
|
||||
return isinstance(a, (int, float, type(None)))
|
||||
|
||||
min_shape = () if _isscalar(min) else min.shape
|
||||
max_shape = () if _isscalar(max) else max.shape
|
||||
|
||||
wrapped_xp = array_namespace(x)
|
||||
|
||||
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
|
||||
|
||||
# np.clip does type promotion but the array API clip requires that the
|
||||
# output have the same dtype as x. We do this instead of just downcasting
|
||||
# the result of xp.clip() to handle some corner cases better (e.g.,
|
||||
# avoiding uint64 -> float64 promotion).
|
||||
|
||||
# Note: cases where min or max overflow (integer) or round (float) in the
|
||||
# wrong direction when downcasting to x.dtype are unspecified. This code
|
||||
# just does whatever NumPy does when it downcasts in the assignment, but
|
||||
# other behavior could be preferred, especially for integers. For example,
|
||||
# this code produces:
|
||||
|
||||
# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
|
||||
# -128
|
||||
|
||||
# but an answer of 0 might be preferred. See
|
||||
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
|
||||
|
||||
# At least handle the case of Python integers correctly (see
|
||||
# https://github.com/numpy/numpy/pull/26892).
|
||||
if wrapped_xp.isdtype(x.dtype, "integral"):
|
||||
if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
|
||||
min = None
|
||||
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
|
||||
max = None
|
||||
|
||||
dev = _get_device(x)
|
||||
if out is None:
|
||||
out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
|
||||
assert out is not None # workaround for a type-narrowing issue in pyright
|
||||
out[()] = x
|
||||
|
||||
if min is not None:
|
||||
a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev)
|
||||
a = xp.broadcast_to(a, result_shape)
|
||||
ia = (out < a) | xp.isnan(a)
|
||||
out[ia] = a[ia]
|
||||
|
||||
if max is not None:
|
||||
b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev)
|
||||
b = xp.broadcast_to(b, result_shape)
|
||||
ib = (out > b) | xp.isnan(b)
|
||||
out[ib] = b[ib]
|
||||
|
||||
# Return a scalar for 0-D
|
||||
return out[()]
|
||||
|
||||
|
||||
# Unlike transpose(), the axes argument to permute_dims() is required.
|
||||
def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array:
|
||||
return xp.transpose(x, axes)
|
||||
|
||||
|
||||
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
|
||||
def reshape(
|
||||
x: Array,
|
||||
/,
|
||||
shape: tuple[int, ...],
|
||||
xp: Namespace,
|
||||
*,
|
||||
copy: Optional[bool] = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
if copy is True:
|
||||
x = x.copy()
|
||||
elif copy is False:
|
||||
y = x.view()
|
||||
y.shape = shape
|
||||
return y
|
||||
return xp.reshape(x, shape, **kwargs)
|
||||
|
||||
|
||||
# The descending keyword is new in sort and argsort, and 'kind' replaced with
|
||||
# 'stable'
|
||||
def argsort(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int = -1,
|
||||
descending: bool = False,
|
||||
stable: bool = True,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
# Note: this keyword argument is different, and the default is different.
|
||||
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
|
||||
# as the default whereas cupy.sort uses kind=None.
|
||||
if stable:
|
||||
kwargs["kind"] = "stable"
|
||||
if not descending:
|
||||
res = xp.argsort(x, axis=axis, **kwargs)
|
||||
else:
|
||||
# As NumPy has no native descending sort, we imitate it here. Note that
|
||||
# simply flipping the results of xp.argsort(x, ...) would not
|
||||
# respect the relative order like it would in native descending sorts.
|
||||
res = xp.flip(
|
||||
xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs),
|
||||
axis=axis,
|
||||
)
|
||||
# Rely on flip()/argsort() to validate axis
|
||||
normalised_axis = axis if axis >= 0 else x.ndim + axis
|
||||
max_i = x.shape[normalised_axis] - 1
|
||||
res = max_i - res
|
||||
return res
|
||||
|
||||
|
||||
def sort(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int = -1,
|
||||
descending: bool = False,
|
||||
stable: bool = True,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
# Note: this keyword argument is different, and the default is different.
|
||||
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
|
||||
# as the default whereas cupy.sort uses kind=None.
|
||||
if stable:
|
||||
kwargs["kind"] = "stable"
|
||||
res = xp.sort(x, axis=axis, **kwargs)
|
||||
if descending:
|
||||
res = xp.flip(res, axis=axis)
|
||||
return res
|
||||
|
||||
|
||||
# nonzero should error for zero-dimensional arrays
|
||||
def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
|
||||
if x.ndim == 0:
|
||||
raise ValueError("nonzero() does not support zero-dimensional arrays")
|
||||
return xp.nonzero(x, **kwargs)
|
||||
|
||||
|
||||
# ceil, floor, and trunc return integers for integer inputs
|
||||
|
||||
|
||||
def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
|
||||
if xp.issubdtype(x.dtype, xp.integer):
|
||||
return x
|
||||
return xp.ceil(x, **kwargs)
|
||||
|
||||
|
||||
def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
|
||||
if xp.issubdtype(x.dtype, xp.integer):
|
||||
return x
|
||||
return xp.floor(x, **kwargs)
|
||||
|
||||
|
||||
def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
|
||||
if xp.issubdtype(x.dtype, xp.integer):
|
||||
return x
|
||||
return xp.trunc(x, **kwargs)
|
||||
|
||||
|
||||
# linear algebra functions
|
||||
|
||||
|
||||
def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
|
||||
return xp.matmul(x1, x2, **kwargs)
|
||||
|
||||
|
||||
# Unlike transpose, matrix_transpose only transposes the last two axes.
|
||||
def matrix_transpose(x: Array, /, xp: Namespace) -> Array:
|
||||
if x.ndim < 2:
|
||||
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
|
||||
return xp.swapaxes(x, -1, -2)
|
||||
|
||||
|
||||
def tensordot(
|
||||
x1: Array,
|
||||
x2: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axes: int | tuple[Sequence[int], Sequence[int]] = 2,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
return xp.tensordot(x1, x2, axes=axes, **kwargs)
|
||||
|
||||
|
||||
def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
|
||||
if x1.shape[axis] != x2.shape[axis]:
|
||||
raise ValueError("x1 and x2 must have the same size along the given axis")
|
||||
|
||||
if hasattr(xp, "broadcast_tensors"):
|
||||
_broadcast = xp.broadcast_tensors
|
||||
else:
|
||||
_broadcast = xp.broadcast_arrays
|
||||
|
||||
x1_ = xp.moveaxis(x1, axis, -1)
|
||||
x2_ = xp.moveaxis(x2, axis, -1)
|
||||
x1_, x2_ = _broadcast(x1_, x2_)
|
||||
|
||||
res = xp.conj(x1_[..., None, :]) @ x2_[..., None]
|
||||
return res[..., 0, 0]
|
||||
|
||||
|
||||
# isdtype is a new function in the 2022.12 array API specification.
|
||||
|
||||
|
||||
def isdtype(
|
||||
dtype: DType,
|
||||
kind: DType | str | tuple[DType | str, ...],
|
||||
xp: Namespace,
|
||||
*,
|
||||
_tuple: bool = True, # Disallow nested tuples
|
||||
) -> bool:
|
||||
"""
|
||||
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
|
||||
|
||||
Note that outside of this function, this compat library does not yet fully
|
||||
support complex numbers.
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
|
||||
for more details
|
||||
"""
|
||||
if isinstance(kind, tuple) and _tuple:
|
||||
return any(
|
||||
isdtype(dtype, k, xp, _tuple=False)
|
||||
for k in cast("tuple[DType | str, ...]", kind)
|
||||
)
|
||||
elif isinstance(kind, str):
|
||||
if kind == "bool":
|
||||
return dtype == xp.bool_
|
||||
elif kind == "signed integer":
|
||||
return xp.issubdtype(dtype, xp.signedinteger)
|
||||
elif kind == "unsigned integer":
|
||||
return xp.issubdtype(dtype, xp.unsignedinteger)
|
||||
elif kind == "integral":
|
||||
return xp.issubdtype(dtype, xp.integer)
|
||||
elif kind == "real floating":
|
||||
return xp.issubdtype(dtype, xp.floating)
|
||||
elif kind == "complex floating":
|
||||
return xp.issubdtype(dtype, xp.complexfloating)
|
||||
elif kind == "numeric":
|
||||
return xp.issubdtype(dtype, xp.number)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized data type kind: {kind!r}")
|
||||
else:
|
||||
# This will allow things that aren't required by the spec, like
|
||||
# isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
|
||||
# more strict here to match the type annotation? Note that the
|
||||
# array_api_strict implementation will be very strict.
|
||||
return dtype == kind
|
||||
|
||||
|
||||
# unstack is a new function in the 2023.12 array API standard
|
||||
def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]:
|
||||
if x.ndim == 0:
|
||||
raise ValueError("Input array must be at least 1-d.")
|
||||
return tuple(xp.moveaxis(x, axis, 0))
|
||||
|
||||
|
||||
# numpy 1.26 does not use the standard definition for sign on complex numbers
|
||||
|
||||
|
||||
def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
|
||||
if isdtype(x.dtype, "complex floating", xp=xp):
|
||||
out = (x / xp.abs(x, **kwargs))[...]
|
||||
# sign(0) = 0 but the above formula would give nan
|
||||
out[x == 0j] = 0j
|
||||
else:
|
||||
out = xp.sign(x, **kwargs)
|
||||
# CuPy sign() does not propagate nans. See
|
||||
# https://github.com/data-apis/array-api-compat/issues/136
|
||||
if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
|
||||
out[xp.isnan(x)] = xp.nan
|
||||
return out[()]
|
||||
|
||||
|
||||
def finfo(type_: DType | Array, /, xp: Namespace) -> Any:
|
||||
# It is surprisingly difficult to recognize a dtype apart from an array.
|
||||
# np.int64 is not the same as np.asarray(1).dtype!
|
||||
try:
|
||||
return xp.finfo(type_)
|
||||
except (ValueError, TypeError):
|
||||
return xp.finfo(type_.dtype)
|
||||
|
||||
|
||||
def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
|
||||
try:
|
||||
return xp.iinfo(type_)
|
||||
except (ValueError, TypeError):
|
||||
return xp.iinfo(type_.dtype)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"arange",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"eye",
|
||||
"full",
|
||||
"full_like",
|
||||
"linspace",
|
||||
"ones",
|
||||
"ones_like",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
"UniqueAllResult",
|
||||
"UniqueCountsResult",
|
||||
"UniqueInverseResult",
|
||||
"unique_all",
|
||||
"unique_counts",
|
||||
"unique_inverse",
|
||||
"unique_values",
|
||||
"std",
|
||||
"var",
|
||||
"cumulative_sum",
|
||||
"cumulative_prod",
|
||||
"clip",
|
||||
"permute_dims",
|
||||
"reshape",
|
||||
"argsort",
|
||||
"sort",
|
||||
"nonzero",
|
||||
"ceil",
|
||||
"floor",
|
||||
"trunc",
|
||||
"matmul",
|
||||
"matrix_transpose",
|
||||
"tensordot",
|
||||
"vecdot",
|
||||
"isdtype",
|
||||
"unstack",
|
||||
"sign",
|
||||
"finfo",
|
||||
"iinfo",
|
||||
]
|
||||
_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
213
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/_fft.py
vendored
Normal file
213
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/_fft.py
vendored
Normal file
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from ._typing import Array, Device, DType, Namespace
|
||||
|
||||
_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
|
||||
|
||||
# Note: NumPy fft functions improperly upcast float32 and complex64 to
|
||||
# complex128, which is why we require wrapping them all here.
|
||||
|
||||
def fft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def ifft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def fftn(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
s: Sequence[int] | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def ifftn(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
s: Sequence[int] | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def rfft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype == xp.float32:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def irfft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype == xp.complex64:
|
||||
return res.astype(xp.float32)
|
||||
return res
|
||||
|
||||
def rfftn(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
s: Sequence[int] | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
|
||||
if x.dtype == xp.float32:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def irfftn(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
s: Sequence[int] | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
|
||||
if x.dtype == xp.complex64:
|
||||
return res.astype(xp.float32)
|
||||
return res
|
||||
|
||||
def hfft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.float32)
|
||||
return res
|
||||
|
||||
def ihfft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def fftfreq(
|
||||
n: int,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
d: float = 1.0,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
) -> Array:
|
||||
if device not in ["cpu", None]:
|
||||
raise ValueError(f"Unsupported device {device!r}")
|
||||
res = xp.fft.fftfreq(n, d=d)
|
||||
if dtype is not None:
|
||||
return res.astype(dtype)
|
||||
return res
|
||||
|
||||
def rfftfreq(
|
||||
n: int,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
d: float = 1.0,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
) -> Array:
|
||||
if device not in ["cpu", None]:
|
||||
raise ValueError(f"Unsupported device {device!r}")
|
||||
res = xp.fft.rfftfreq(n, d=d)
|
||||
if dtype is not None:
|
||||
return res.astype(dtype)
|
||||
return res
|
||||
|
||||
def fftshift(
|
||||
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
|
||||
) -> Array:
|
||||
return xp.fft.fftshift(x, axes=axes)
|
||||
|
||||
def ifftshift(
|
||||
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
|
||||
) -> Array:
|
||||
return xp.fft.ifftshift(x, axes=axes)
|
||||
|
||||
__all__ = [
|
||||
"fft",
|
||||
"ifft",
|
||||
"fftn",
|
||||
"ifftn",
|
||||
"rfft",
|
||||
"irfft",
|
||||
"rfftn",
|
||||
"irfftn",
|
||||
"hfft",
|
||||
"ihfft",
|
||||
"fftfreq",
|
||||
"rfftfreq",
|
||||
"fftshift",
|
||||
"ifftshift",
|
||||
]
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
1058
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/_helpers.py
vendored
Normal file
1058
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/_helpers.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
232
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/_linalg.py
vendored
Normal file
232
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/_linalg.py
vendored
Normal file
@@ -0,0 +1,232 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Literal, NamedTuple, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
if np.__version__[0] == "2":
|
||||
from numpy.lib.array_utils import normalize_axis_tuple
|
||||
else:
|
||||
from numpy.core.numeric import normalize_axis_tuple
|
||||
|
||||
from .._internal import get_xp
|
||||
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
|
||||
from ._typing import Array, DType, JustFloat, JustInt, Namespace
|
||||
|
||||
|
||||
# These are in the main NumPy namespace but not in numpy.linalg
|
||||
def cross(
|
||||
x1: Array,
|
||||
x2: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int = -1,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
return xp.cross(x1, x2, axis=axis, **kwargs)
|
||||
|
||||
def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
|
||||
return xp.outer(x1, x2, **kwargs)
|
||||
|
||||
class EighResult(NamedTuple):
|
||||
eigenvalues: Array
|
||||
eigenvectors: Array
|
||||
|
||||
class QRResult(NamedTuple):
|
||||
Q: Array
|
||||
R: Array
|
||||
|
||||
class SlogdetResult(NamedTuple):
|
||||
sign: Array
|
||||
logabsdet: Array
|
||||
|
||||
class SVDResult(NamedTuple):
|
||||
U: Array
|
||||
S: Array
|
||||
Vh: Array
|
||||
|
||||
# These functions are the same as their NumPy counterparts except they return
|
||||
# a namedtuple.
|
||||
def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult:
|
||||
return EighResult(*xp.linalg.eigh(x, **kwargs))
|
||||
|
||||
def qr(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
mode: Literal["reduced", "complete"] = "reduced",
|
||||
**kwargs: object,
|
||||
) -> QRResult:
|
||||
return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
|
||||
|
||||
def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult:
|
||||
return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
|
||||
|
||||
def svd(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
full_matrices: bool = True,
|
||||
**kwargs: object,
|
||||
) -> SVDResult:
|
||||
return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
|
||||
|
||||
# These functions have additional keyword arguments
|
||||
|
||||
# The upper keyword argument is new from NumPy
|
||||
def cholesky(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
upper: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
L = xp.linalg.cholesky(x, **kwargs)
|
||||
if upper:
|
||||
U = get_xp(xp)(matrix_transpose)(L)
|
||||
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
|
||||
U = xp.conj(U) # pyright: ignore[reportConstantRedefinition]
|
||||
return U
|
||||
return L
|
||||
|
||||
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
|
||||
# Note that it has a different semantic meaning from tol and rcond.
|
||||
def matrix_rank(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
rtol: float | Array | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
# this is different from xp.linalg.matrix_rank, which supports 1
|
||||
# dimensional arrays.
|
||||
if x.ndim < 2:
|
||||
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
|
||||
S: Array = get_xp(xp)(svdvals)(x, **kwargs)
|
||||
if rtol is None:
|
||||
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
|
||||
else:
|
||||
# this is different from xp.linalg.matrix_rank, which does not
|
||||
# multiply the tolerance by the largest singular value.
|
||||
tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis]
|
||||
return xp.count_nonzero(S > tol, axis=-1)
|
||||
|
||||
def pinv(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
rtol: float | Array | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
# this is different from xp.linalg.pinv, which does not multiply the
|
||||
# default tolerance by max(M, N).
|
||||
if rtol is None:
|
||||
rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps
|
||||
return xp.linalg.pinv(x, rcond=rtol, **kwargs)
|
||||
|
||||
# These functions are new in the array API spec
|
||||
|
||||
def matrix_norm(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
keepdims: bool = False,
|
||||
ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro",
|
||||
) -> Array:
|
||||
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
|
||||
|
||||
# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
|
||||
# xp.linalg.svd(compute_uv=False).
|
||||
def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]:
|
||||
return xp.linalg.svd(x, compute_uv=False)
|
||||
|
||||
def vector_norm(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int | tuple[int, ...] | None = None,
|
||||
keepdims: bool = False,
|
||||
ord: JustInt | JustFloat = 2,
|
||||
) -> Array:
|
||||
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
|
||||
# when axis=None and the input is 2-D, so to force a vector norm, we make
|
||||
# it so the input is 1-D (for axis=None), or reshape so that norm is done
|
||||
# on a single dimension.
|
||||
if axis is None:
|
||||
# Note: xp.linalg.norm() doesn't handle 0-D arrays
|
||||
_x = x.ravel()
|
||||
_axis = 0
|
||||
elif isinstance(axis, tuple):
|
||||
# Note: The axis argument supports any number of axes, whereas
|
||||
# xp.linalg.norm() only supports a single axis for vector norm.
|
||||
normalized_axis = cast(
|
||||
"tuple[int, ...]",
|
||||
normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue]
|
||||
)
|
||||
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
|
||||
newshape = axis + rest
|
||||
_x = xp.transpose(x, newshape).reshape(
|
||||
(math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
|
||||
_axis = 0
|
||||
else:
|
||||
_x = x
|
||||
_axis = axis
|
||||
|
||||
res = xp.linalg.norm(_x, axis=_axis, ord=ord)
|
||||
|
||||
if keepdims:
|
||||
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
|
||||
# above to avoid matrix norm logic.
|
||||
shape = list(x.shape)
|
||||
_axis = cast(
|
||||
"tuple[int, ...]",
|
||||
normalize_axis_tuple( # pyright: ignore[reportCallIssue]
|
||||
range(x.ndim) if axis is None else axis,
|
||||
x.ndim,
|
||||
),
|
||||
)
|
||||
for i in _axis:
|
||||
shape[i] = 1
|
||||
res = xp.reshape(res, tuple(shape))
|
||||
|
||||
return res
|
||||
|
||||
# xp.diagonal and xp.trace operate on the first two axes whereas these
|
||||
# operates on the last two
|
||||
|
||||
def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
|
||||
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
|
||||
|
||||
def trace(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
offset: int = 0,
|
||||
dtype: DType | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
return xp.asarray(
|
||||
xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
|
||||
)
|
||||
|
||||
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
|
||||
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
|
||||
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
|
||||
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
|
||||
'trace']
|
||||
|
||||
_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
192
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/_typing.py
vendored
Normal file
192
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/common/_typing.py
vendored
Normal file
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from types import ModuleType as Namespace
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
final,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import Incomplete
|
||||
|
||||
SupportsBufferProtocol: TypeAlias = Incomplete
|
||||
Array: TypeAlias = Incomplete
|
||||
Device: TypeAlias = Incomplete
|
||||
DType: TypeAlias = Incomplete
|
||||
else:
|
||||
SupportsBufferProtocol = object
|
||||
Array = object
|
||||
Device = object
|
||||
DType = object
|
||||
|
||||
|
||||
_T_co = TypeVar("_T_co", covariant=True)
|
||||
|
||||
|
||||
# These "Just" types are equivalent to the `Just` type from the `optype` library,
|
||||
# apart from them not being `@runtime_checkable`.
|
||||
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
|
||||
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
|
||||
@final
|
||||
class JustInt(Protocol):
|
||||
@property
|
||||
def __class__(self, /) -> type[int]: ...
|
||||
@__class__.setter
|
||||
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
|
||||
|
||||
@final
|
||||
class JustFloat(Protocol):
|
||||
@property
|
||||
def __class__(self, /) -> type[float]: ...
|
||||
@__class__.setter
|
||||
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
|
||||
|
||||
@final
|
||||
class JustComplex(Protocol):
|
||||
@property
|
||||
def __class__(self, /) -> type[complex]: ...
|
||||
@__class__.setter
|
||||
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
|
||||
|
||||
#
|
||||
|
||||
|
||||
class NestedSequence(Protocol[_T_co]):
|
||||
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
|
||||
def __len__(self, /) -> int: ...
|
||||
|
||||
|
||||
class SupportsArrayNamespace(Protocol[_T_co]):
|
||||
def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
|
||||
|
||||
|
||||
class HasShape(Protocol[_T_co]):
|
||||
@property
|
||||
def shape(self, /) -> _T_co: ...
|
||||
|
||||
|
||||
# Return type of `__array_namespace_info__.default_dtypes`
|
||||
Capabilities = TypedDict(
|
||||
"Capabilities",
|
||||
{
|
||||
"boolean indexing": bool,
|
||||
"data-dependent shapes": bool,
|
||||
"max dimensions": int,
|
||||
},
|
||||
)
|
||||
|
||||
# Return type of `__array_namespace_info__.default_dtypes`
|
||||
DefaultDTypes = TypedDict(
|
||||
"DefaultDTypes",
|
||||
{
|
||||
"real floating": DType,
|
||||
"complex floating": DType,
|
||||
"integral": DType,
|
||||
"indexing": DType,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
_DTypeKind: TypeAlias = Literal[
|
||||
"bool",
|
||||
"signed integer",
|
||||
"unsigned integer",
|
||||
"integral",
|
||||
"real floating",
|
||||
"complex floating",
|
||||
"numeric",
|
||||
]
|
||||
# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
|
||||
DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="bool")`
|
||||
class DTypesBool(TypedDict):
|
||||
bool: DType
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="signed integer")`
|
||||
class DTypesSigned(TypedDict):
|
||||
int8: DType
|
||||
int16: DType
|
||||
int32: DType
|
||||
int64: DType
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="unsigned integer")`
|
||||
class DTypesUnsigned(TypedDict):
|
||||
uint8: DType
|
||||
uint16: DType
|
||||
uint32: DType
|
||||
uint64: DType
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="integral")`
|
||||
class DTypesIntegral(DTypesSigned, DTypesUnsigned):
|
||||
pass
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="real floating")`
|
||||
class DTypesReal(TypedDict):
|
||||
float32: DType
|
||||
float64: DType
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="complex floating")`
|
||||
class DTypesComplex(TypedDict):
|
||||
complex64: DType
|
||||
complex128: DType
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="numeric")`
|
||||
class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
|
||||
pass
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind=None)` (default)
|
||||
class DTypesAll(DTypesBool, DTypesNumeric):
|
||||
pass
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind=?)` (fallback)
|
||||
DTypesAny: TypeAlias = Mapping[str, DType]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Array",
|
||||
"Capabilities",
|
||||
"DType",
|
||||
"DTypeKind",
|
||||
"DTypesAny",
|
||||
"DTypesAll",
|
||||
"DTypesBool",
|
||||
"DTypesNumeric",
|
||||
"DTypesIntegral",
|
||||
"DTypesSigned",
|
||||
"DTypesUnsigned",
|
||||
"DTypesReal",
|
||||
"DTypesComplex",
|
||||
"DefaultDTypes",
|
||||
"Device",
|
||||
"HasShape",
|
||||
"Namespace",
|
||||
"JustInt",
|
||||
"JustFloat",
|
||||
"JustComplex",
|
||||
"NestedSequence",
|
||||
"SupportsArrayNamespace",
|
||||
"SupportsBufferProtocol",
|
||||
]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
13
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/__init__.py
vendored
Normal file
13
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/__init__.py
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
from cupy import * # noqa: F403
|
||||
|
||||
# from cupy import * doesn't overwrite these builtin names
|
||||
from cupy import abs, max, min, round # noqa: F401
|
||||
|
||||
# These imports may overwrite names from the import * above.
|
||||
from ._aliases import * # noqa: F403
|
||||
|
||||
# See the comment in the numpy __init__.py
|
||||
__import__(__package__ + '.linalg')
|
||||
__import__(__package__ + '.fft')
|
||||
|
||||
__array_api_version__ = '2024.12'
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
156
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/_aliases.py
vendored
Normal file
156
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/_aliases.py
vendored
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import cupy as cp
|
||||
|
||||
from ..common import _aliases, _helpers
|
||||
from ..common._typing import NestedSequence, SupportsBufferProtocol
|
||||
from .._internal import get_xp
|
||||
from ._info import __array_namespace_info__
|
||||
from ._typing import Array, Device, DType
|
||||
|
||||
bool = cp.bool_
|
||||
|
||||
# Basic renames
|
||||
acos = cp.arccos
|
||||
acosh = cp.arccosh
|
||||
asin = cp.arcsin
|
||||
asinh = cp.arcsinh
|
||||
atan = cp.arctan
|
||||
atan2 = cp.arctan2
|
||||
atanh = cp.arctanh
|
||||
bitwise_left_shift = cp.left_shift
|
||||
bitwise_invert = cp.invert
|
||||
bitwise_right_shift = cp.right_shift
|
||||
concat = cp.concatenate
|
||||
pow = cp.power
|
||||
|
||||
arange = get_xp(cp)(_aliases.arange)
|
||||
empty = get_xp(cp)(_aliases.empty)
|
||||
empty_like = get_xp(cp)(_aliases.empty_like)
|
||||
eye = get_xp(cp)(_aliases.eye)
|
||||
full = get_xp(cp)(_aliases.full)
|
||||
full_like = get_xp(cp)(_aliases.full_like)
|
||||
linspace = get_xp(cp)(_aliases.linspace)
|
||||
ones = get_xp(cp)(_aliases.ones)
|
||||
ones_like = get_xp(cp)(_aliases.ones_like)
|
||||
zeros = get_xp(cp)(_aliases.zeros)
|
||||
zeros_like = get_xp(cp)(_aliases.zeros_like)
|
||||
UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult)
|
||||
UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult)
|
||||
UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult)
|
||||
unique_all = get_xp(cp)(_aliases.unique_all)
|
||||
unique_counts = get_xp(cp)(_aliases.unique_counts)
|
||||
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
|
||||
unique_values = get_xp(cp)(_aliases.unique_values)
|
||||
std = get_xp(cp)(_aliases.std)
|
||||
var = get_xp(cp)(_aliases.var)
|
||||
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
|
||||
cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
|
||||
clip = get_xp(cp)(_aliases.clip)
|
||||
permute_dims = get_xp(cp)(_aliases.permute_dims)
|
||||
reshape = get_xp(cp)(_aliases.reshape)
|
||||
argsort = get_xp(cp)(_aliases.argsort)
|
||||
sort = get_xp(cp)(_aliases.sort)
|
||||
nonzero = get_xp(cp)(_aliases.nonzero)
|
||||
ceil = get_xp(cp)(_aliases.ceil)
|
||||
floor = get_xp(cp)(_aliases.floor)
|
||||
trunc = get_xp(cp)(_aliases.trunc)
|
||||
matmul = get_xp(cp)(_aliases.matmul)
|
||||
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
|
||||
tensordot = get_xp(cp)(_aliases.tensordot)
|
||||
sign = get_xp(cp)(_aliases.sign)
|
||||
finfo = get_xp(cp)(_aliases.finfo)
|
||||
iinfo = get_xp(cp)(_aliases.iinfo)
|
||||
|
||||
|
||||
# asarray also adds the copy keyword, which is not present in numpy 1.0.
|
||||
def asarray(
|
||||
obj: (
|
||||
Array
|
||||
| bool | int | float | complex
|
||||
| NestedSequence[bool | int | float | complex]
|
||||
| SupportsBufferProtocol
|
||||
),
|
||||
/,
|
||||
*,
|
||||
dtype: Optional[DType] = None,
|
||||
device: Optional[Device] = None,
|
||||
copy: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for asarray().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
with cp.cuda.Device(device):
|
||||
if copy is None:
|
||||
return cp.asarray(obj, dtype=dtype, **kwargs)
|
||||
else:
|
||||
res = cp.array(obj, dtype=dtype, copy=copy, **kwargs)
|
||||
if not copy and res is not obj:
|
||||
raise ValueError("Unable to avoid copy while creating an array as requested")
|
||||
return res
|
||||
|
||||
|
||||
def astype(
|
||||
x: Array,
|
||||
dtype: DType,
|
||||
/,
|
||||
*,
|
||||
copy: bool = True,
|
||||
device: Optional[Device] = None,
|
||||
) -> Array:
|
||||
if device is None:
|
||||
return x.astype(dtype=dtype, copy=copy)
|
||||
out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
|
||||
return out.copy() if copy and out is x else out
|
||||
|
||||
|
||||
# cupy.count_nonzero does not have keepdims
|
||||
def count_nonzero(
|
||||
x: Array,
|
||||
axis=None,
|
||||
keepdims=False
|
||||
) -> Array:
|
||||
result = cp.count_nonzero(x, axis)
|
||||
if keepdims:
|
||||
if axis is None:
|
||||
return cp.reshape(result, [1]*x.ndim)
|
||||
return cp.expand_dims(result, axis)
|
||||
return result
|
||||
|
||||
|
||||
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
|
||||
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
|
||||
return cp.take_along_axis(x, indices, axis=axis)
|
||||
|
||||
|
||||
# These functions are completely new here. If the library already has them
|
||||
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
||||
if hasattr(cp, 'vecdot'):
|
||||
vecdot = cp.vecdot
|
||||
else:
|
||||
vecdot = get_xp(cp)(_aliases.vecdot)
|
||||
|
||||
if hasattr(cp, 'isdtype'):
|
||||
isdtype = cp.isdtype
|
||||
else:
|
||||
isdtype = get_xp(cp)(_aliases.isdtype)
|
||||
|
||||
if hasattr(cp, 'unstack'):
|
||||
unstack = cp.unstack
|
||||
else:
|
||||
unstack = get_xp(cp)(_aliases.unstack)
|
||||
|
||||
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
|
||||
'acos', 'acosh', 'asin', 'asinh', 'atan',
|
||||
'atan2', 'atanh', 'bitwise_left_shift',
|
||||
'bitwise_invert', 'bitwise_right_shift',
|
||||
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
|
||||
'take_along_axis']
|
||||
|
||||
_all_ignore = ['cp', 'get_xp']
|
||||
336
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/_info.py
vendored
Normal file
336
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/_info.py
vendored
Normal file
@@ -0,0 +1,336 @@
|
||||
"""
|
||||
Array API Inspection namespace
|
||||
|
||||
This is the namespace for inspection functions as defined by the array API
|
||||
standard. See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html for
|
||||
more details.
|
||||
|
||||
"""
|
||||
from cupy import (
|
||||
dtype,
|
||||
cuda,
|
||||
bool_ as bool,
|
||||
intp,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
float32,
|
||||
float64,
|
||||
complex64,
|
||||
complex128,
|
||||
)
|
||||
|
||||
|
||||
class __array_namespace_info__:
|
||||
"""
|
||||
Get the array API inspection namespace for CuPy.
|
||||
|
||||
The array API inspection namespace defines the following functions:
|
||||
|
||||
- capabilities()
|
||||
- default_device()
|
||||
- default_dtypes()
|
||||
- dtypes()
|
||||
- devices()
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html
|
||||
for more details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
info : ModuleType
|
||||
The array API inspection namespace for CuPy.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': cupy.float64,
|
||||
'complex floating': cupy.complex128,
|
||||
'integral': cupy.int64,
|
||||
'indexing': cupy.int64}
|
||||
|
||||
"""
|
||||
|
||||
__module__ = 'cupy'
|
||||
|
||||
def capabilities(self):
|
||||
"""
|
||||
Return a dictionary of array API library capabilities.
|
||||
|
||||
The resulting dictionary has the following keys:
|
||||
|
||||
- **"boolean indexing"**: boolean indicating whether an array library
|
||||
supports boolean indexing. Always ``True`` for CuPy.
|
||||
|
||||
- **"data-dependent shapes"**: boolean indicating whether an array
|
||||
library supports data-dependent output shapes. Always ``True`` for
|
||||
CuPy.
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
|
||||
for more details.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
capabilities : dict
|
||||
A dictionary of array API library capabilities.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.capabilities()
|
||||
{'boolean indexing': True,
|
||||
'data-dependent shapes': True,
|
||||
'max dimensions': 64}
|
||||
|
||||
"""
|
||||
return {
|
||||
"boolean indexing": True,
|
||||
"data-dependent shapes": True,
|
||||
"max dimensions": 64,
|
||||
}
|
||||
|
||||
def default_device(self):
|
||||
"""
|
||||
The default device used for new CuPy arrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
device : Device
|
||||
The default device used for new CuPy arrays.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_device()
|
||||
Device(0)
|
||||
|
||||
Notes
|
||||
-----
|
||||
This method returns the static default device when CuPy is initialized.
|
||||
However, the *current* device used by creation functions (``empty`` etc.)
|
||||
can be changed globally or with a context manager.
|
||||
|
||||
See Also
|
||||
--------
|
||||
https://github.com/data-apis/array-api/issues/835
|
||||
"""
|
||||
return cuda.Device(0)
|
||||
|
||||
def default_dtypes(self, *, device=None):
|
||||
"""
|
||||
The default data types used for new CuPy arrays.
|
||||
|
||||
For CuPy, this always returns the following dictionary:
|
||||
|
||||
- **"real floating"**: ``cupy.float64``
|
||||
- **"complex floating"**: ``cupy.complex128``
|
||||
- **"integral"**: ``cupy.intp``
|
||||
- **"indexing"**: ``cupy.intp``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the default data types for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary describing the default data types used for new CuPy
|
||||
arrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': cupy.float64,
|
||||
'complex floating': cupy.complex128,
|
||||
'integral': cupy.int64,
|
||||
'indexing': cupy.int64}
|
||||
|
||||
"""
|
||||
# TODO: Does this depend on device?
|
||||
return {
|
||||
"real floating": dtype(float64),
|
||||
"complex floating": dtype(complex128),
|
||||
"integral": dtype(intp),
|
||||
"indexing": dtype(intp),
|
||||
}
|
||||
|
||||
def dtypes(self, *, device=None, kind=None):
|
||||
"""
|
||||
The array API data types supported by CuPy.
|
||||
|
||||
Note that this function only returns data types that are defined by
|
||||
the array API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the data types for.
|
||||
kind : str or tuple of str, optional
|
||||
The kind of data types to return. If ``None``, all data types are
|
||||
returned. If a string, only data types of that kind are returned.
|
||||
If a tuple, a dictionary containing the union of the given kinds
|
||||
is returned. The following kinds are supported:
|
||||
|
||||
- ``'bool'``: boolean data types (i.e., ``bool``).
|
||||
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
|
||||
``int16``, ``int32``, ``int64``).
|
||||
- ``'unsigned integer'``: unsigned integer data types (i.e.,
|
||||
``uint8``, ``uint16``, ``uint32``, ``uint64``).
|
||||
- ``'integral'``: integer data types. Shorthand for ``('signed
|
||||
integer', 'unsigned integer')``.
|
||||
- ``'real floating'``: real-valued floating-point data types
|
||||
(i.e., ``float32``, ``float64``).
|
||||
- ``'complex floating'``: complex floating-point data types (i.e.,
|
||||
``complex64``, ``complex128``).
|
||||
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
|
||||
'real floating', 'complex floating')``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary mapping the names of data types to the corresponding
|
||||
CuPy data types.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.dtypes(kind='signed integer')
|
||||
{'int8': cupy.int8,
|
||||
'int16': cupy.int16,
|
||||
'int32': cupy.int32,
|
||||
'int64': cupy.int64}
|
||||
|
||||
"""
|
||||
# TODO: Does this depend on device?
|
||||
if kind is None:
|
||||
return {
|
||||
"bool": dtype(bool),
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "bool":
|
||||
return {"bool": bool}
|
||||
if kind == "signed integer":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
}
|
||||
if kind == "unsigned integer":
|
||||
return {
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "integral":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "real floating":
|
||||
return {
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
}
|
||||
if kind == "complex floating":
|
||||
return {
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "numeric":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if isinstance(kind, tuple):
|
||||
res = {}
|
||||
for k in kind:
|
||||
res.update(self.dtypes(kind=k))
|
||||
return res
|
||||
raise ValueError(f"unsupported kind: {kind!r}")
|
||||
|
||||
def devices(self):
|
||||
"""
|
||||
The devices supported by CuPy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
devices : list[Device]
|
||||
The devices supported by CuPy.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes
|
||||
|
||||
"""
|
||||
return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())]
|
||||
31
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/_typing.py
vendored
Normal file
31
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/_typing.py
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["Array", "DType", "Device"]
|
||||
_all_ignore = ["cp"]
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import cupy as cp
|
||||
from cupy import ndarray as Array
|
||||
from cupy.cuda.device import Device
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# NumPy 1.x on Python 3.10 fails to parse np.dtype[]
|
||||
DType = cp.dtype[
|
||||
cp.intp
|
||||
| cp.int8
|
||||
| cp.int16
|
||||
| cp.int32
|
||||
| cp.int64
|
||||
| cp.uint8
|
||||
| cp.uint16
|
||||
| cp.uint32
|
||||
| cp.uint64
|
||||
| cp.float32
|
||||
| cp.float64
|
||||
| cp.complex64
|
||||
| cp.complex128
|
||||
| cp.bool_
|
||||
]
|
||||
else:
|
||||
DType = cp.dtype
|
||||
36
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/fft.py
vendored
Normal file
36
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/fft.py
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
from cupy.fft import * # noqa: F403
|
||||
# cupy.fft doesn't have __all__. If it is added, replace this with
|
||||
#
|
||||
# from cupy.fft import __all__ as linalg_all
|
||||
_n = {}
|
||||
exec('from cupy.fft import *', _n)
|
||||
del _n['__builtins__']
|
||||
fft_all = list(_n)
|
||||
del _n
|
||||
|
||||
from ..common import _fft
|
||||
from .._internal import get_xp
|
||||
|
||||
import cupy as cp
|
||||
|
||||
fft = get_xp(cp)(_fft.fft)
|
||||
ifft = get_xp(cp)(_fft.ifft)
|
||||
fftn = get_xp(cp)(_fft.fftn)
|
||||
ifftn = get_xp(cp)(_fft.ifftn)
|
||||
rfft = get_xp(cp)(_fft.rfft)
|
||||
irfft = get_xp(cp)(_fft.irfft)
|
||||
rfftn = get_xp(cp)(_fft.rfftn)
|
||||
irfftn = get_xp(cp)(_fft.irfftn)
|
||||
hfft = get_xp(cp)(_fft.hfft)
|
||||
ihfft = get_xp(cp)(_fft.ihfft)
|
||||
fftfreq = get_xp(cp)(_fft.fftfreq)
|
||||
rfftfreq = get_xp(cp)(_fft.rfftfreq)
|
||||
fftshift = get_xp(cp)(_fft.fftshift)
|
||||
ifftshift = get_xp(cp)(_fft.ifftshift)
|
||||
|
||||
__all__ = fft_all + _fft.__all__
|
||||
|
||||
del get_xp
|
||||
del cp
|
||||
del fft_all
|
||||
del _fft
|
||||
49
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/linalg.py
vendored
Normal file
49
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/cupy/linalg.py
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
from cupy.linalg import * # noqa: F403
|
||||
# cupy.linalg doesn't have __all__. If it is added, replace this with
|
||||
#
|
||||
# from cupy.linalg import __all__ as linalg_all
|
||||
_n = {}
|
||||
exec('from cupy.linalg import *', _n)
|
||||
del _n['__builtins__']
|
||||
linalg_all = list(_n)
|
||||
del _n
|
||||
|
||||
from ..common import _linalg
|
||||
from .._internal import get_xp
|
||||
|
||||
import cupy as cp
|
||||
|
||||
# These functions are in both the main and linalg namespaces
|
||||
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
|
||||
|
||||
cross = get_xp(cp)(_linalg.cross)
|
||||
outer = get_xp(cp)(_linalg.outer)
|
||||
EighResult = _linalg.EighResult
|
||||
QRResult = _linalg.QRResult
|
||||
SlogdetResult = _linalg.SlogdetResult
|
||||
SVDResult = _linalg.SVDResult
|
||||
eigh = get_xp(cp)(_linalg.eigh)
|
||||
qr = get_xp(cp)(_linalg.qr)
|
||||
slogdet = get_xp(cp)(_linalg.slogdet)
|
||||
svd = get_xp(cp)(_linalg.svd)
|
||||
cholesky = get_xp(cp)(_linalg.cholesky)
|
||||
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
|
||||
pinv = get_xp(cp)(_linalg.pinv)
|
||||
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
|
||||
svdvals = get_xp(cp)(_linalg.svdvals)
|
||||
diagonal = get_xp(cp)(_linalg.diagonal)
|
||||
trace = get_xp(cp)(_linalg.trace)
|
||||
|
||||
# These functions are completely new here. If the library already has them
|
||||
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
||||
if hasattr(cp.linalg, 'vector_norm'):
|
||||
vector_norm = cp.linalg.vector_norm
|
||||
else:
|
||||
vector_norm = get_xp(cp)(_linalg.vector_norm)
|
||||
|
||||
__all__ = linalg_all + _linalg.__all__
|
||||
|
||||
del get_xp
|
||||
del cp
|
||||
del linalg_all
|
||||
del _linalg
|
||||
0
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/__init__.py
vendored
Normal file
0
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/__init__.py
vendored
Normal file
Binary file not shown.
12
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/array/__init__.py
vendored
Normal file
12
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/array/__init__.py
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Final
|
||||
|
||||
from dask.array import * # noqa: F403
|
||||
|
||||
# These imports may overwrite names from the import * above.
|
||||
from ._aliases import * # noqa: F403
|
||||
|
||||
__array_api_version__: Final = "2024.12"
|
||||
|
||||
# See the comment in the numpy __init__.py
|
||||
__import__(__package__ + '.linalg')
|
||||
__import__(__package__ + '.fft')
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
376
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/array/_aliases.py
vendored
Normal file
376
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/array/_aliases.py
vendored
Normal file
@@ -0,0 +1,376 @@
|
||||
# pyright: reportPrivateUsage=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownMemberType=false
|
||||
# pyright: reportUnknownVariableType=false
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from builtins import bool as py_bool
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import dask.array as da
|
||||
import numpy as np
|
||||
from numpy import bool_ as bool
|
||||
from numpy import (
|
||||
can_cast,
|
||||
complex64,
|
||||
complex128,
|
||||
float32,
|
||||
float64,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
result_type,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
)
|
||||
|
||||
from ..._internal import get_xp
|
||||
from ...common import _aliases, _helpers, array_namespace
|
||||
from ...common._typing import (
|
||||
Array,
|
||||
Device,
|
||||
DType,
|
||||
NestedSequence,
|
||||
SupportsBufferProtocol,
|
||||
)
|
||||
from ._info import __array_namespace_info__
|
||||
|
||||
isdtype = get_xp(np)(_aliases.isdtype)
|
||||
unstack = get_xp(da)(_aliases.unstack)
|
||||
|
||||
|
||||
# da.astype doesn't respect copy=True
|
||||
def astype(
|
||||
x: Array,
|
||||
dtype: DType,
|
||||
/,
|
||||
*,
|
||||
copy: py_bool = True,
|
||||
device: Device | None = None,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for astype().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
# TODO: respect device keyword?
|
||||
_helpers._check_device(da, device)
|
||||
|
||||
if not copy and dtype == x.dtype:
|
||||
return x
|
||||
x = x.astype(dtype)
|
||||
return x.copy() if copy else x
|
||||
|
||||
|
||||
# Common aliases
|
||||
|
||||
|
||||
# This arange func is modified from the common one to
|
||||
# not pass stop/step as keyword arguments, which will cause
|
||||
# an error with dask
|
||||
def arange(
|
||||
start: float,
|
||||
/,
|
||||
stop: float | None = None,
|
||||
step: float = 1,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for arange().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
# TODO: respect device keyword?
|
||||
_helpers._check_device(da, device)
|
||||
|
||||
args: list[Any] = [start]
|
||||
if stop is not None:
|
||||
args.append(stop)
|
||||
else:
|
||||
# stop is None, so start is actually stop
|
||||
# prepend the default value for start which is 0
|
||||
args.insert(0, 0)
|
||||
args.append(step)
|
||||
|
||||
return da.arange(*args, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
eye = get_xp(da)(_aliases.eye)
|
||||
linspace = get_xp(da)(_aliases.linspace)
|
||||
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
|
||||
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
|
||||
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
|
||||
unique_all = get_xp(da)(_aliases.unique_all)
|
||||
unique_counts = get_xp(da)(_aliases.unique_counts)
|
||||
unique_inverse = get_xp(da)(_aliases.unique_inverse)
|
||||
unique_values = get_xp(da)(_aliases.unique_values)
|
||||
permute_dims = get_xp(da)(_aliases.permute_dims)
|
||||
std = get_xp(da)(_aliases.std)
|
||||
var = get_xp(da)(_aliases.var)
|
||||
cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
|
||||
cumulative_prod = get_xp(da)(_aliases.cumulative_prod)
|
||||
empty = get_xp(da)(_aliases.empty)
|
||||
empty_like = get_xp(da)(_aliases.empty_like)
|
||||
full = get_xp(da)(_aliases.full)
|
||||
full_like = get_xp(da)(_aliases.full_like)
|
||||
ones = get_xp(da)(_aliases.ones)
|
||||
ones_like = get_xp(da)(_aliases.ones_like)
|
||||
zeros = get_xp(da)(_aliases.zeros)
|
||||
zeros_like = get_xp(da)(_aliases.zeros_like)
|
||||
reshape = get_xp(da)(_aliases.reshape)
|
||||
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
|
||||
vecdot = get_xp(da)(_aliases.vecdot)
|
||||
nonzero = get_xp(da)(_aliases.nonzero)
|
||||
ceil = get_xp(np)(_aliases.ceil)
|
||||
floor = get_xp(np)(_aliases.floor)
|
||||
trunc = get_xp(np)(_aliases.trunc)
|
||||
matmul = get_xp(np)(_aliases.matmul)
|
||||
tensordot = get_xp(np)(_aliases.tensordot)
|
||||
sign = get_xp(np)(_aliases.sign)
|
||||
finfo = get_xp(np)(_aliases.finfo)
|
||||
iinfo = get_xp(np)(_aliases.iinfo)
|
||||
|
||||
|
||||
# asarray also adds the copy keyword, which is not present in numpy 1.0.
|
||||
def asarray(
|
||||
obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol,
|
||||
/,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
copy: py_bool | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for asarray().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
# TODO: respect device keyword?
|
||||
_helpers._check_device(da, device)
|
||||
|
||||
if isinstance(obj, da.Array):
|
||||
if dtype is not None and dtype != obj.dtype:
|
||||
if copy is False:
|
||||
raise ValueError("Unable to avoid copy when changing dtype")
|
||||
obj = obj.astype(dtype)
|
||||
return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
if copy is False:
|
||||
raise ValueError(
|
||||
"Unable to avoid copy when converting a non-dask object to dask"
|
||||
)
|
||||
|
||||
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
|
||||
# see https://github.com/dask/dask/pull/11524/
|
||||
obj = np.array(obj, dtype=dtype, copy=True)
|
||||
return da.from_array(obj)
|
||||
|
||||
|
||||
# Element wise aliases
|
||||
from dask.array import arccos as acos
|
||||
from dask.array import arccosh as acosh
|
||||
from dask.array import arcsin as asin
|
||||
from dask.array import arcsinh as asinh
|
||||
from dask.array import arctan as atan
|
||||
from dask.array import arctan2 as atan2
|
||||
from dask.array import arctanh as atanh
|
||||
|
||||
# Other
|
||||
from dask.array import concatenate as concat
|
||||
from dask.array import invert as bitwise_invert
|
||||
from dask.array import left_shift as bitwise_left_shift
|
||||
from dask.array import power as pow
|
||||
from dask.array import right_shift as bitwise_right_shift
|
||||
|
||||
|
||||
# dask.array.clip does not work unless all three arguments are provided.
|
||||
# Furthermore, the masking workaround in common._aliases.clip cannot work with
|
||||
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
|
||||
# now).
|
||||
def clip(
|
||||
x: Array,
|
||||
/,
|
||||
min: float | Array | None = None,
|
||||
max: float | Array | None = None,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for clip().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
|
||||
def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]:
|
||||
return a is None or isinstance(a, (int, float))
|
||||
|
||||
min_shape = () if _isscalar(min) else min.shape
|
||||
max_shape = () if _isscalar(max) else max.shape
|
||||
|
||||
# TODO: This won't handle dask unknown shapes
|
||||
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
|
||||
|
||||
if min is not None:
|
||||
min = da.broadcast_to(da.asarray(min), result_shape)
|
||||
if max is not None:
|
||||
max = da.broadcast_to(da.asarray(max), result_shape)
|
||||
|
||||
if min is None and max is None:
|
||||
return da.positive(x)
|
||||
|
||||
if min is None:
|
||||
return astype(da.minimum(x, max), x.dtype)
|
||||
if max is None:
|
||||
return astype(da.maximum(x, min), x.dtype)
|
||||
|
||||
return astype(da.minimum(da.maximum(x, min), max), x.dtype)
|
||||
|
||||
|
||||
def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
|
||||
"""
|
||||
Make sure that Array is not broken into multiple chunks along axis.
|
||||
|
||||
Returns
|
||||
-------
|
||||
x : Array
|
||||
The input Array with a single chunk along axis.
|
||||
restore : Callable[Array, Array]
|
||||
function to apply to the output to rechunk it back into reasonable chunks
|
||||
"""
|
||||
if axis < 0:
|
||||
axis += x.ndim
|
||||
if x.numblocks[axis] < 2:
|
||||
return x, lambda x: x
|
||||
|
||||
# Break chunks on other axes in an attempt to keep chunk size low
|
||||
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
|
||||
|
||||
# Rather than reconstructing the original chunks, which can be a
|
||||
# very expensive affair, just break down oversized chunks without
|
||||
# incurring in any transfers over the network.
|
||||
# This has the downside of a risk of overchunking if the array is
|
||||
# then used in operations against other arrays that match the
|
||||
# original chunking pattern.
|
||||
return x, lambda x: x.rechunk()
|
||||
|
||||
|
||||
def sort(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: int = -1,
|
||||
descending: py_bool = False,
|
||||
stable: py_bool = True,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility layer around the lack of sort() in Dask.
|
||||
|
||||
Warnings
|
||||
--------
|
||||
This function temporarily rechunks the array along `axis` to a single chunk.
|
||||
This can be extremely inefficient and can lead to out-of-memory errors.
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
x, restore = _ensure_single_chunk(x, axis)
|
||||
|
||||
meta_xp = array_namespace(x._meta)
|
||||
x = da.map_blocks(
|
||||
meta_xp.sort,
|
||||
x,
|
||||
axis=axis,
|
||||
meta=x._meta,
|
||||
dtype=x.dtype,
|
||||
descending=descending,
|
||||
stable=stable,
|
||||
)
|
||||
|
||||
return restore(x)
|
||||
|
||||
|
||||
def argsort(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: int = -1,
|
||||
descending: py_bool = False,
|
||||
stable: py_bool = True,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility layer around the lack of argsort() in Dask.
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
|
||||
Warnings
|
||||
--------
|
||||
This function temporarily rechunks the array along `axis` into a single chunk.
|
||||
This can be extremely inefficient and can lead to out-of-memory errors.
|
||||
"""
|
||||
x, restore = _ensure_single_chunk(x, axis)
|
||||
|
||||
meta_xp = array_namespace(x._meta)
|
||||
dtype = meta_xp.argsort(x._meta).dtype
|
||||
meta = meta_xp.astype(x._meta, dtype)
|
||||
x = da.map_blocks(
|
||||
meta_xp.argsort,
|
||||
x,
|
||||
axis=axis,
|
||||
meta=meta,
|
||||
dtype=dtype,
|
||||
descending=descending,
|
||||
stable=stable,
|
||||
)
|
||||
|
||||
return restore(x)
|
||||
|
||||
|
||||
# dask.array.count_nonzero does not have keepdims
|
||||
def count_nonzero(
|
||||
x: Array,
|
||||
axis: int | None = None,
|
||||
keepdims: py_bool = False,
|
||||
) -> Array:
|
||||
result = da.count_nonzero(x, axis)
|
||||
if keepdims:
|
||||
if axis is None:
|
||||
return da.reshape(result, [1] * x.ndim)
|
||||
return da.expand_dims(result, axis)
|
||||
return result
|
||||
|
||||
|
||||
__all__ = [
|
||||
"__array_namespace_info__",
|
||||
"count_nonzero",
|
||||
"bool",
|
||||
"int8", "int16", "int32", "int64",
|
||||
"uint8", "uint16", "uint32", "uint64",
|
||||
"float32", "float64",
|
||||
"complex64", "complex128",
|
||||
"asarray", "astype", "can_cast", "result_type",
|
||||
"pow",
|
||||
"concat",
|
||||
"acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh",
|
||||
"bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
|
||||
] # fmt: skip
|
||||
__all__ += _aliases.__all__
|
||||
_all_ignore = ["array_namespace", "get_xp", "da", "np"]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
416
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/array/_info.py
vendored
Normal file
416
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/array/_info.py
vendored
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
Array API Inspection namespace
|
||||
|
||||
This is the namespace for inspection functions as defined by the array API
|
||||
standard. See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html for
|
||||
more details.
|
||||
|
||||
"""
|
||||
|
||||
# pyright: reportPrivateUsage=false
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal as L
|
||||
from typing import TypeAlias, overload
|
||||
|
||||
from numpy import bool_ as bool
|
||||
from numpy import (
|
||||
complex64,
|
||||
complex128,
|
||||
dtype,
|
||||
float32,
|
||||
float64,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
intp,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
)
|
||||
|
||||
from ...common._helpers import _DASK_DEVICE, _dask_device
|
||||
from ...common._typing import (
|
||||
Capabilities,
|
||||
DefaultDTypes,
|
||||
DType,
|
||||
DTypeKind,
|
||||
DTypesAll,
|
||||
DTypesAny,
|
||||
DTypesBool,
|
||||
DTypesComplex,
|
||||
DTypesIntegral,
|
||||
DTypesNumeric,
|
||||
DTypesReal,
|
||||
DTypesSigned,
|
||||
DTypesUnsigned,
|
||||
)
|
||||
|
||||
_Device: TypeAlias = L["cpu"] | _dask_device
|
||||
|
||||
|
||||
class __array_namespace_info__:
|
||||
"""
|
||||
Get the array API inspection namespace for Dask.
|
||||
|
||||
The array API inspection namespace defines the following functions:
|
||||
|
||||
- capabilities()
|
||||
- default_device()
|
||||
- default_dtypes()
|
||||
- dtypes()
|
||||
- devices()
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html
|
||||
for more details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
info : ModuleType
|
||||
The array API inspection namespace for Dask.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': dask.float64,
|
||||
'complex floating': dask.complex128,
|
||||
'integral': dask.int64,
|
||||
'indexing': dask.int64}
|
||||
|
||||
"""
|
||||
|
||||
__module__ = "dask.array"
|
||||
|
||||
def capabilities(self) -> Capabilities:
|
||||
"""
|
||||
Return a dictionary of array API library capabilities.
|
||||
|
||||
The resulting dictionary has the following keys:
|
||||
|
||||
- **"boolean indexing"**: boolean indicating whether an array library
|
||||
supports boolean indexing.
|
||||
|
||||
Dask support boolean indexing as long as both the index
|
||||
and the indexed arrays have known shapes.
|
||||
Note however that the output .shape and .size properties
|
||||
will contain a non-compliant math.nan instead of None.
|
||||
|
||||
- **"data-dependent shapes"**: boolean indicating whether an array
|
||||
library supports data-dependent output shapes.
|
||||
|
||||
Dask implements unique_values et.al.
|
||||
Note however that the output .shape and .size properties
|
||||
will contain a non-compliant math.nan instead of None.
|
||||
|
||||
- **"max dimensions"**: integer indicating the maximum number of
|
||||
dimensions supported by the array library.
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
|
||||
for more details.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
capabilities : dict
|
||||
A dictionary of array API library capabilities.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.capabilities()
|
||||
{'boolean indexing': True,
|
||||
'data-dependent shapes': True,
|
||||
'max dimensions': 64}
|
||||
|
||||
"""
|
||||
return {
|
||||
"boolean indexing": True,
|
||||
"data-dependent shapes": True,
|
||||
"max dimensions": 64,
|
||||
}
|
||||
|
||||
def default_device(self) -> L["cpu"]:
|
||||
"""
|
||||
The default device used for new Dask arrays.
|
||||
|
||||
For Dask, this always returns ``'cpu'``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
device : Device
|
||||
The default device used for new Dask arrays.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_device()
|
||||
'cpu'
|
||||
|
||||
"""
|
||||
return "cpu"
|
||||
|
||||
def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes:
|
||||
"""
|
||||
The default data types used for new Dask arrays.
|
||||
|
||||
For Dask, this always returns the following dictionary:
|
||||
|
||||
- **"real floating"**: ``numpy.float64``
|
||||
- **"complex floating"**: ``numpy.complex128``
|
||||
- **"integral"**: ``numpy.intp``
|
||||
- **"indexing"**: ``numpy.intp``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the default data types for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary describing the default data types used for new Dask
|
||||
arrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': dask.float64,
|
||||
'complex floating': dask.complex128,
|
||||
'integral': dask.int64,
|
||||
'indexing': dask.int64}
|
||||
|
||||
"""
|
||||
if device not in ["cpu", _DASK_DEVICE, None]:
|
||||
raise ValueError(
|
||||
f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, '
|
||||
f"but received: {device!r}"
|
||||
)
|
||||
return {
|
||||
"real floating": dtype(float64),
|
||||
"complex floating": dtype(complex128),
|
||||
"integral": dtype(intp),
|
||||
"indexing": dtype(intp),
|
||||
}
|
||||
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: _Device | None = None, kind: None = None
|
||||
) -> DTypesAll: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: _Device | None = None, kind: L["bool"]
|
||||
) -> DTypesBool: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: _Device | None = None, kind: L["signed integer"]
|
||||
) -> DTypesSigned: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: _Device | None = None, kind: L["unsigned integer"]
|
||||
) -> DTypesUnsigned: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: _Device | None = None, kind: L["integral"]
|
||||
) -> DTypesIntegral: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: _Device | None = None, kind: L["real floating"]
|
||||
) -> DTypesReal: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: _Device | None = None, kind: L["complex floating"]
|
||||
) -> DTypesComplex: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: _Device | None = None, kind: L["numeric"]
|
||||
) -> DTypesNumeric: ...
|
||||
def dtypes(
|
||||
self, /, *, device: _Device | None = None, kind: DTypeKind | None = None
|
||||
) -> DTypesAny:
|
||||
"""
|
||||
The array API data types supported by Dask.
|
||||
|
||||
Note that this function only returns data types that are defined by
|
||||
the array API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the data types for.
|
||||
kind : str or tuple of str, optional
|
||||
The kind of data types to return. If ``None``, all data types are
|
||||
returned. If a string, only data types of that kind are returned.
|
||||
If a tuple, a dictionary containing the union of the given kinds
|
||||
is returned. The following kinds are supported:
|
||||
|
||||
- ``'bool'``: boolean data types (i.e., ``bool``).
|
||||
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
|
||||
``int16``, ``int32``, ``int64``).
|
||||
- ``'unsigned integer'``: unsigned integer data types (i.e.,
|
||||
``uint8``, ``uint16``, ``uint32``, ``uint64``).
|
||||
- ``'integral'``: integer data types. Shorthand for ``('signed
|
||||
integer', 'unsigned integer')``.
|
||||
- ``'real floating'``: real-valued floating-point data types
|
||||
(i.e., ``float32``, ``float64``).
|
||||
- ``'complex floating'``: complex floating-point data types (i.e.,
|
||||
``complex64``, ``complex128``).
|
||||
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
|
||||
'real floating', 'complex floating')``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary mapping the names of data types to the corresponding
|
||||
Dask data types.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.dtypes(kind='signed integer')
|
||||
{'int8': dask.int8,
|
||||
'int16': dask.int16,
|
||||
'int32': dask.int32,
|
||||
'int64': dask.int64}
|
||||
|
||||
"""
|
||||
if device not in ["cpu", _DASK_DEVICE, None]:
|
||||
raise ValueError(
|
||||
'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
|
||||
f" {device}"
|
||||
)
|
||||
if kind is None:
|
||||
return {
|
||||
"bool": dtype(bool),
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "bool":
|
||||
return {"bool": bool}
|
||||
if kind == "signed integer":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
}
|
||||
if kind == "unsigned integer":
|
||||
return {
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "integral":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "real floating":
|
||||
return {
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
}
|
||||
if kind == "complex floating":
|
||||
return {
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "numeric":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall]
|
||||
res: dict[str, DType] = {}
|
||||
for k in kind:
|
||||
res.update(self.dtypes(kind=k))
|
||||
return res
|
||||
raise ValueError(f"unsupported kind: {kind!r}")
|
||||
|
||||
def devices(self) -> list[_Device]:
|
||||
"""
|
||||
The devices supported by Dask.
|
||||
|
||||
For Dask, this always returns ``['cpu', DASK_DEVICE]``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
devices : list[Device]
|
||||
The devices supported by Dask.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.devices()
|
||||
['cpu', DASK_DEVICE]
|
||||
|
||||
"""
|
||||
return ["cpu", _DASK_DEVICE]
|
||||
21
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/array/fft.py
vendored
Normal file
21
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/array/fft.py
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
from dask.array.fft import * # noqa: F403
|
||||
# dask.array.fft doesn't have __all__. If it is added, replace this with
|
||||
#
|
||||
# from dask.array.fft import __all__ as linalg_all
|
||||
_n = {}
|
||||
exec('from dask.array.fft import *', _n)
|
||||
for k in ("__builtins__", "Sequence", "annotations", "warnings"):
|
||||
_n.pop(k, None)
|
||||
fft_all = list(_n)
|
||||
del _n, k
|
||||
|
||||
from ...common import _fft
|
||||
from ..._internal import get_xp
|
||||
|
||||
import dask.array as da
|
||||
|
||||
fftfreq = get_xp(da)(_fft.fftfreq)
|
||||
rfftfreq = get_xp(da)(_fft.rfftfreq)
|
||||
|
||||
__all__ = fft_all + ["fftfreq", "rfftfreq"]
|
||||
_all_ignore = ["da", "fft_all", "get_xp", "warnings"]
|
||||
72
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/array/linalg.py
vendored
Normal file
72
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/dask/array/linalg.py
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import dask.array as da
|
||||
|
||||
# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
|
||||
from dask.array import matmul, outer, tensordot
|
||||
|
||||
# Exports
|
||||
from dask.array.linalg import * # noqa: F403
|
||||
|
||||
from ..._internal import get_xp
|
||||
from ...common import _linalg
|
||||
from ...common._typing import Array as _Array
|
||||
from ._aliases import matrix_transpose, vecdot
|
||||
|
||||
# dask.array.linalg doesn't have __all__. If it is added, replace this with
|
||||
#
|
||||
# from dask.array.linalg import __all__ as linalg_all
|
||||
_n = {}
|
||||
exec('from dask.array.linalg import *', _n)
|
||||
for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'):
|
||||
_n.pop(k, None)
|
||||
linalg_all = list(_n)
|
||||
del _n, k
|
||||
|
||||
EighResult = _linalg.EighResult
|
||||
QRResult = _linalg.QRResult
|
||||
SlogdetResult = _linalg.SlogdetResult
|
||||
SVDResult = _linalg.SVDResult
|
||||
# TODO: use the QR wrapper once dask
|
||||
# supports the mode keyword on QR
|
||||
# https://github.com/dask/dask/issues/10388
|
||||
#qr = get_xp(da)(_linalg.qr)
|
||||
def qr(
|
||||
x: _Array,
|
||||
mode: Literal["reduced", "complete"] = "reduced",
|
||||
**kwargs: object,
|
||||
) -> QRResult:
|
||||
if mode != "reduced":
|
||||
raise ValueError("dask arrays only support using mode='reduced'")
|
||||
return QRResult(*da.linalg.qr(x, **kwargs))
|
||||
trace = get_xp(da)(_linalg.trace)
|
||||
cholesky = get_xp(da)(_linalg.cholesky)
|
||||
matrix_rank = get_xp(da)(_linalg.matrix_rank)
|
||||
matrix_norm = get_xp(da)(_linalg.matrix_norm)
|
||||
|
||||
|
||||
# Wrap the svd functions to not pass full_matrices to dask
|
||||
# when full_matrices=False (as that is the default behavior for dask),
|
||||
# and dask doesn't have the full_matrices keyword
|
||||
def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult:
|
||||
if full_matrices:
|
||||
raise ValueError("full_matrics=True is not supported by dask.")
|
||||
return da.linalg.svd(x, coerce_signs=False, **kwargs)
|
||||
|
||||
def svdvals(x: _Array) -> _Array:
|
||||
# TODO: can't avoid computing U or V for dask
|
||||
_, s, _ = svd(x)
|
||||
return s
|
||||
|
||||
vector_norm = get_xp(da)(_linalg.vector_norm)
|
||||
diagonal = get_xp(da)(_linalg.diagonal)
|
||||
|
||||
__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
|
||||
"matrix_transpose", "vecdot", "EighResult",
|
||||
"QRResult", "SlogdetResult", "SVDResult", "qr",
|
||||
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
|
||||
"vector_norm", "diagonal"]
|
||||
|
||||
_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings']
|
||||
28
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/__init__.py
vendored
Normal file
28
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/__init__.py
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
# ruff: noqa: PLC0414
|
||||
from typing import Final
|
||||
|
||||
from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary]
|
||||
|
||||
# from numpy import * doesn't overwrite these builtin names
|
||||
from numpy import abs as abs
|
||||
from numpy import max as max
|
||||
from numpy import min as min
|
||||
from numpy import round as round
|
||||
|
||||
# These imports may overwrite names from the import * above.
|
||||
from ._aliases import * # noqa: F403
|
||||
|
||||
# Don't know why, but we have to do an absolute import to import linalg. If we
|
||||
# instead do
|
||||
#
|
||||
# from . import linalg
|
||||
#
|
||||
# It doesn't overwrite np.linalg from above. The import is generated
|
||||
# dynamically so that the library can be vendored.
|
||||
__import__(__package__ + ".linalg")
|
||||
|
||||
__import__(__package__ + ".fft")
|
||||
|
||||
from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401
|
||||
|
||||
__array_api_version__: Final = "2024.12"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
190
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/_aliases.py
vendored
Normal file
190
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/_aliases.py
vendored
Normal file
@@ -0,0 +1,190 @@
|
||||
# pyright: reportPrivateUsage=false
|
||||
from __future__ import annotations
|
||||
|
||||
from builtins import bool as py_bool
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._internal import get_xp
|
||||
from ..common import _aliases, _helpers
|
||||
from ..common._typing import NestedSequence, SupportsBufferProtocol
|
||||
from ._info import __array_namespace_info__
|
||||
from ._typing import Array, Device, DType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Buffer, TypeIs
|
||||
|
||||
# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`:
|
||||
# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10
|
||||
_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
|
||||
|
||||
bool = np.bool_
|
||||
|
||||
# Basic renames
|
||||
acos = np.arccos
|
||||
acosh = np.arccosh
|
||||
asin = np.arcsin
|
||||
asinh = np.arcsinh
|
||||
atan = np.arctan
|
||||
atan2 = np.arctan2
|
||||
atanh = np.arctanh
|
||||
bitwise_left_shift = np.left_shift
|
||||
bitwise_invert = np.invert
|
||||
bitwise_right_shift = np.right_shift
|
||||
concat = np.concatenate
|
||||
pow = np.power
|
||||
|
||||
arange = get_xp(np)(_aliases.arange)
|
||||
empty = get_xp(np)(_aliases.empty)
|
||||
empty_like = get_xp(np)(_aliases.empty_like)
|
||||
eye = get_xp(np)(_aliases.eye)
|
||||
full = get_xp(np)(_aliases.full)
|
||||
full_like = get_xp(np)(_aliases.full_like)
|
||||
linspace = get_xp(np)(_aliases.linspace)
|
||||
ones = get_xp(np)(_aliases.ones)
|
||||
ones_like = get_xp(np)(_aliases.ones_like)
|
||||
zeros = get_xp(np)(_aliases.zeros)
|
||||
zeros_like = get_xp(np)(_aliases.zeros_like)
|
||||
UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult)
|
||||
UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult)
|
||||
UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult)
|
||||
unique_all = get_xp(np)(_aliases.unique_all)
|
||||
unique_counts = get_xp(np)(_aliases.unique_counts)
|
||||
unique_inverse = get_xp(np)(_aliases.unique_inverse)
|
||||
unique_values = get_xp(np)(_aliases.unique_values)
|
||||
std = get_xp(np)(_aliases.std)
|
||||
var = get_xp(np)(_aliases.var)
|
||||
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
|
||||
cumulative_prod = get_xp(np)(_aliases.cumulative_prod)
|
||||
clip = get_xp(np)(_aliases.clip)
|
||||
permute_dims = get_xp(np)(_aliases.permute_dims)
|
||||
reshape = get_xp(np)(_aliases.reshape)
|
||||
argsort = get_xp(np)(_aliases.argsort)
|
||||
sort = get_xp(np)(_aliases.sort)
|
||||
nonzero = get_xp(np)(_aliases.nonzero)
|
||||
ceil = get_xp(np)(_aliases.ceil)
|
||||
floor = get_xp(np)(_aliases.floor)
|
||||
trunc = get_xp(np)(_aliases.trunc)
|
||||
matmul = get_xp(np)(_aliases.matmul)
|
||||
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
|
||||
tensordot = get_xp(np)(_aliases.tensordot)
|
||||
sign = get_xp(np)(_aliases.sign)
|
||||
finfo = get_xp(np)(_aliases.finfo)
|
||||
iinfo = get_xp(np)(_aliases.iinfo)
|
||||
|
||||
|
||||
def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction]
|
||||
try:
|
||||
memoryview(obj) # pyright: ignore[reportArgumentType]
|
||||
except TypeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# asarray also adds the copy keyword, which is not present in numpy 1.0.
|
||||
# asarray() is different enough between numpy, cupy, and dask, the logic
|
||||
# complicated enough that it's easier to define it separately for each module
|
||||
# rather than trying to combine everything into one function in common/
|
||||
def asarray(
|
||||
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
|
||||
/,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
copy: _Copy | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for asarray().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
_helpers._check_device(np, device)
|
||||
|
||||
if copy is None:
|
||||
copy = np._CopyMode.IF_NEEDED
|
||||
elif copy is False:
|
||||
copy = np._CopyMode.NEVER
|
||||
elif copy is True:
|
||||
copy = np._CopyMode.ALWAYS
|
||||
|
||||
return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore
|
||||
|
||||
|
||||
def astype(
|
||||
x: Array,
|
||||
dtype: DType,
|
||||
/,
|
||||
*,
|
||||
copy: py_bool = True,
|
||||
device: Device | None = None,
|
||||
) -> Array:
|
||||
_helpers._check_device(np, device)
|
||||
return x.astype(dtype=dtype, copy=copy)
|
||||
|
||||
|
||||
# count_nonzero returns a python int for axis=None and keepdims=False
|
||||
# https://github.com/numpy/numpy/issues/17562
|
||||
def count_nonzero(
|
||||
x: Array,
|
||||
axis: int | tuple[int, ...] | None = None,
|
||||
keepdims: py_bool = False,
|
||||
) -> Array:
|
||||
# NOTE: this is currently incorrectly typed in numpy, but will be fixed in
|
||||
# numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
|
||||
result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue]
|
||||
if axis is None and not keepdims:
|
||||
return np.asarray(result)
|
||||
return result
|
||||
|
||||
|
||||
# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
|
||||
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
|
||||
return np.take_along_axis(x, indices, axis=axis)
|
||||
|
||||
|
||||
# These functions are completely new here. If the library already has them
|
||||
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
||||
if hasattr(np, "vecdot"):
|
||||
vecdot = np.vecdot
|
||||
else:
|
||||
vecdot = get_xp(np)(_aliases.vecdot)
|
||||
|
||||
if hasattr(np, "isdtype"):
|
||||
isdtype = np.isdtype
|
||||
else:
|
||||
isdtype = get_xp(np)(_aliases.isdtype)
|
||||
|
||||
if hasattr(np, "unstack"):
|
||||
unstack = np.unstack
|
||||
else:
|
||||
unstack = get_xp(np)(_aliases.unstack)
|
||||
|
||||
__all__ = [
|
||||
"__array_namespace_info__",
|
||||
"asarray",
|
||||
"astype",
|
||||
"acos",
|
||||
"acosh",
|
||||
"asin",
|
||||
"asinh",
|
||||
"atan",
|
||||
"atan2",
|
||||
"atanh",
|
||||
"bitwise_left_shift",
|
||||
"bitwise_invert",
|
||||
"bitwise_right_shift",
|
||||
"bool",
|
||||
"concat",
|
||||
"count_nonzero",
|
||||
"pow",
|
||||
"take_along_axis"
|
||||
]
|
||||
__all__ += _aliases.__all__
|
||||
_all_ignore = ["np", "get_xp"]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
366
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/_info.py
vendored
Normal file
366
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/_info.py
vendored
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
Array API Inspection namespace
|
||||
|
||||
This is the namespace for inspection functions as defined by the array API
|
||||
standard. See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html for
|
||||
more details.
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from numpy import bool_ as bool
|
||||
from numpy import (
|
||||
complex64,
|
||||
complex128,
|
||||
dtype,
|
||||
float32,
|
||||
float64,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
intp,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
)
|
||||
|
||||
from ._typing import Device, DType
|
||||
|
||||
|
||||
class __array_namespace_info__:
|
||||
"""
|
||||
Get the array API inspection namespace for NumPy.
|
||||
|
||||
The array API inspection namespace defines the following functions:
|
||||
|
||||
- capabilities()
|
||||
- default_device()
|
||||
- default_dtypes()
|
||||
- dtypes()
|
||||
- devices()
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html
|
||||
for more details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
info : ModuleType
|
||||
The array API inspection namespace for NumPy.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': numpy.float64,
|
||||
'complex floating': numpy.complex128,
|
||||
'integral': numpy.int64,
|
||||
'indexing': numpy.int64}
|
||||
|
||||
"""
|
||||
|
||||
__module__ = 'numpy'
|
||||
|
||||
def capabilities(self):
|
||||
"""
|
||||
Return a dictionary of array API library capabilities.
|
||||
|
||||
The resulting dictionary has the following keys:
|
||||
|
||||
- **"boolean indexing"**: boolean indicating whether an array library
|
||||
supports boolean indexing. Always ``True`` for NumPy.
|
||||
|
||||
- **"data-dependent shapes"**: boolean indicating whether an array
|
||||
library supports data-dependent output shapes. Always ``True`` for
|
||||
NumPy.
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
|
||||
for more details.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
capabilities : dict
|
||||
A dictionary of array API library capabilities.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.capabilities()
|
||||
{'boolean indexing': True,
|
||||
'data-dependent shapes': True,
|
||||
'max dimensions': 64}
|
||||
|
||||
"""
|
||||
return {
|
||||
"boolean indexing": True,
|
||||
"data-dependent shapes": True,
|
||||
"max dimensions": 64,
|
||||
}
|
||||
|
||||
def default_device(self):
|
||||
"""
|
||||
The default device used for new NumPy arrays.
|
||||
|
||||
For NumPy, this always returns ``'cpu'``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
device : Device
|
||||
The default device used for new NumPy arrays.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.default_device()
|
||||
'cpu'
|
||||
|
||||
"""
|
||||
return "cpu"
|
||||
|
||||
def default_dtypes(
|
||||
self,
|
||||
*,
|
||||
device: Device | None = None,
|
||||
) -> dict[str, dtype[intp | float64 | complex128]]:
|
||||
"""
|
||||
The default data types used for new NumPy arrays.
|
||||
|
||||
For NumPy, this always returns the following dictionary:
|
||||
|
||||
- **"real floating"**: ``numpy.float64``
|
||||
- **"complex floating"**: ``numpy.complex128``
|
||||
- **"integral"**: ``numpy.intp``
|
||||
- **"indexing"**: ``numpy.intp``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the default data types for. For NumPy, only
|
||||
``'cpu'`` is allowed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary describing the default data types used for new NumPy
|
||||
arrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': numpy.float64,
|
||||
'complex floating': numpy.complex128,
|
||||
'integral': numpy.int64,
|
||||
'indexing': numpy.int64}
|
||||
|
||||
"""
|
||||
if device not in ["cpu", None]:
|
||||
raise ValueError(
|
||||
'Device not understood. Only "cpu" is allowed, but received:'
|
||||
f' {device}'
|
||||
)
|
||||
return {
|
||||
"real floating": dtype(float64),
|
||||
"complex floating": dtype(complex128),
|
||||
"integral": dtype(intp),
|
||||
"indexing": dtype(intp),
|
||||
}
|
||||
|
||||
def dtypes(
|
||||
self,
|
||||
*,
|
||||
device: Device | None = None,
|
||||
kind: str | tuple[str, ...] | None = None,
|
||||
) -> dict[str, DType]:
|
||||
"""
|
||||
The array API data types supported by NumPy.
|
||||
|
||||
Note that this function only returns data types that are defined by
|
||||
the array API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the data types for. For NumPy, only ``'cpu'`` is
|
||||
allowed.
|
||||
kind : str or tuple of str, optional
|
||||
The kind of data types to return. If ``None``, all data types are
|
||||
returned. If a string, only data types of that kind are returned.
|
||||
If a tuple, a dictionary containing the union of the given kinds
|
||||
is returned. The following kinds are supported:
|
||||
|
||||
- ``'bool'``: boolean data types (i.e., ``bool``).
|
||||
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
|
||||
``int16``, ``int32``, ``int64``).
|
||||
- ``'unsigned integer'``: unsigned integer data types (i.e.,
|
||||
``uint8``, ``uint16``, ``uint32``, ``uint64``).
|
||||
- ``'integral'``: integer data types. Shorthand for ``('signed
|
||||
integer', 'unsigned integer')``.
|
||||
- ``'real floating'``: real-valued floating-point data types
|
||||
(i.e., ``float32``, ``float64``).
|
||||
- ``'complex floating'``: complex floating-point data types (i.e.,
|
||||
``complex64``, ``complex128``).
|
||||
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
|
||||
'real floating', 'complex floating')``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary mapping the names of data types to the corresponding
|
||||
NumPy data types.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.dtypes(kind='signed integer')
|
||||
{'int8': numpy.int8,
|
||||
'int16': numpy.int16,
|
||||
'int32': numpy.int32,
|
||||
'int64': numpy.int64}
|
||||
|
||||
"""
|
||||
if device not in ["cpu", None]:
|
||||
raise ValueError(
|
||||
'Device not understood. Only "cpu" is allowed, but received:'
|
||||
f' {device}'
|
||||
)
|
||||
if kind is None:
|
||||
return {
|
||||
"bool": dtype(bool),
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "bool":
|
||||
return {"bool": dtype(bool)}
|
||||
if kind == "signed integer":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
}
|
||||
if kind == "unsigned integer":
|
||||
return {
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "integral":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "real floating":
|
||||
return {
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
}
|
||||
if kind == "complex floating":
|
||||
return {
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "numeric":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if isinstance(kind, tuple):
|
||||
res: dict[str, DType] = {}
|
||||
for k in kind:
|
||||
res.update(self.dtypes(kind=k))
|
||||
return res
|
||||
raise ValueError(f"unsupported kind: {kind!r}")
|
||||
|
||||
def devices(self) -> list[Device]:
|
||||
"""
|
||||
The devices supported by NumPy.
|
||||
|
||||
For NumPy, this always returns ``['cpu']``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
devices : list[Device]
|
||||
The devices supported by NumPy.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.devices()
|
||||
['cpu']
|
||||
|
||||
"""
|
||||
return ["cpu"]
|
||||
|
||||
|
||||
__all__ = ["__array_namespace_info__"]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
30
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/_typing.py
vendored
Normal file
30
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/_typing.py
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
|
||||
Device: TypeAlias = Literal["cpu"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
# NumPy 1.x on Python 3.10 fails to parse np.dtype[]
|
||||
DType: TypeAlias = np.dtype[
|
||||
np.bool_
|
||||
| np.integer[Any]
|
||||
| np.float32
|
||||
| np.float64
|
||||
| np.complex64
|
||||
| np.complex128
|
||||
]
|
||||
Array: TypeAlias = np.ndarray[Any, DType]
|
||||
else:
|
||||
DType: TypeAlias = np.dtype
|
||||
Array: TypeAlias = np.ndarray
|
||||
|
||||
__all__ = ["Array", "DType", "Device"]
|
||||
_all_ignore = ["np"]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
35
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/fft.py
vendored
Normal file
35
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/fft.py
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
import numpy as np
|
||||
from numpy.fft import __all__ as fft_all
|
||||
from numpy.fft import fft2, ifft2, irfft2, rfft2
|
||||
|
||||
from .._internal import get_xp
|
||||
from ..common import _fft
|
||||
|
||||
fft = get_xp(np)(_fft.fft)
|
||||
ifft = get_xp(np)(_fft.ifft)
|
||||
fftn = get_xp(np)(_fft.fftn)
|
||||
ifftn = get_xp(np)(_fft.ifftn)
|
||||
rfft = get_xp(np)(_fft.rfft)
|
||||
irfft = get_xp(np)(_fft.irfft)
|
||||
rfftn = get_xp(np)(_fft.rfftn)
|
||||
irfftn = get_xp(np)(_fft.irfftn)
|
||||
hfft = get_xp(np)(_fft.hfft)
|
||||
ihfft = get_xp(np)(_fft.ihfft)
|
||||
fftfreq = get_xp(np)(_fft.fftfreq)
|
||||
rfftfreq = get_xp(np)(_fft.rfftfreq)
|
||||
fftshift = get_xp(np)(_fft.fftshift)
|
||||
ifftshift = get_xp(np)(_fft.ifftshift)
|
||||
|
||||
|
||||
__all__ = ["rfft2", "irfft2", "fft2", "ifft2"]
|
||||
__all__ += _fft.__all__
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
|
||||
|
||||
del get_xp
|
||||
del np
|
||||
del fft_all
|
||||
del _fft
|
||||
143
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/linalg.py
vendored
Normal file
143
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/numpy/linalg.py
vendored
Normal file
@@ -0,0 +1,143 @@
|
||||
# pyright: reportAttributeAccessIssue=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownMemberType=false
|
||||
# pyright: reportUnknownVariableType=false
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__`
|
||||
from numpy.linalg import (
|
||||
LinAlgError,
|
||||
cond,
|
||||
det,
|
||||
eig,
|
||||
eigvals,
|
||||
eigvalsh,
|
||||
inv,
|
||||
lstsq,
|
||||
matrix_power,
|
||||
multi_dot,
|
||||
norm,
|
||||
tensorinv,
|
||||
tensorsolve,
|
||||
)
|
||||
|
||||
from .._internal import get_xp
|
||||
from ..common import _linalg
|
||||
|
||||
# These functions are in both the main and linalg namespaces
|
||||
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
|
||||
from ._typing import Array
|
||||
|
||||
cross = get_xp(np)(_linalg.cross)
|
||||
outer = get_xp(np)(_linalg.outer)
|
||||
EighResult = _linalg.EighResult
|
||||
QRResult = _linalg.QRResult
|
||||
SlogdetResult = _linalg.SlogdetResult
|
||||
SVDResult = _linalg.SVDResult
|
||||
eigh = get_xp(np)(_linalg.eigh)
|
||||
qr = get_xp(np)(_linalg.qr)
|
||||
slogdet = get_xp(np)(_linalg.slogdet)
|
||||
svd = get_xp(np)(_linalg.svd)
|
||||
cholesky = get_xp(np)(_linalg.cholesky)
|
||||
matrix_rank = get_xp(np)(_linalg.matrix_rank)
|
||||
pinv = get_xp(np)(_linalg.pinv)
|
||||
matrix_norm = get_xp(np)(_linalg.matrix_norm)
|
||||
svdvals = get_xp(np)(_linalg.svdvals)
|
||||
diagonal = get_xp(np)(_linalg.diagonal)
|
||||
trace = get_xp(np)(_linalg.trace)
|
||||
|
||||
# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
|
||||
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
|
||||
# of matrices. The np.linalg.solve behavior of allowing stacks of both
|
||||
# matrices and vectors is ambiguous c.f.
|
||||
# https://github.com/numpy/numpy/issues/15349 and
|
||||
# https://github.com/data-apis/array-api/issues/285.
|
||||
|
||||
# To workaround this, the below is the code from np.linalg.solve except
|
||||
# only calling solve1 in the exactly 1D case.
|
||||
|
||||
|
||||
# This code is here instead of in common because it is numpy specific. Also
|
||||
# note that CuPy's solve() does not currently support broadcasting (see
|
||||
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
|
||||
def solve(x1: Array, x2: Array, /) -> Array:
|
||||
try:
|
||||
from numpy.linalg._linalg import (
|
||||
_assert_stacked_2d,
|
||||
_assert_stacked_square,
|
||||
_commonType,
|
||||
_makearray,
|
||||
_raise_linalgerror_singular,
|
||||
isComplexType,
|
||||
)
|
||||
except ImportError:
|
||||
from numpy.linalg.linalg import (
|
||||
_assert_stacked_2d,
|
||||
_assert_stacked_square,
|
||||
_commonType,
|
||||
_makearray,
|
||||
_raise_linalgerror_singular,
|
||||
isComplexType,
|
||||
)
|
||||
from numpy.linalg import _umath_linalg
|
||||
|
||||
x1, _ = _makearray(x1)
|
||||
_assert_stacked_2d(x1)
|
||||
_assert_stacked_square(x1)
|
||||
x2, wrap = _makearray(x2)
|
||||
t, result_t = _commonType(x1, x2)
|
||||
|
||||
# This part is different from np.linalg.solve
|
||||
gufunc: np.ufunc
|
||||
if x2.ndim == 1:
|
||||
gufunc = _umath_linalg.solve1
|
||||
else:
|
||||
gufunc = _umath_linalg.solve
|
||||
|
||||
# This does nothing currently but is left in because it will be relevant
|
||||
# when complex dtype support is added to the spec in 2022.
|
||||
signature = "DD->D" if isComplexType(t) else "dd->d"
|
||||
with np.errstate(
|
||||
call=_raise_linalgerror_singular,
|
||||
invalid="call",
|
||||
over="ignore",
|
||||
divide="ignore",
|
||||
under="ignore",
|
||||
):
|
||||
r: Array = gufunc(x1, x2, signature=signature)
|
||||
|
||||
return wrap(r.astype(result_t, copy=False))
|
||||
|
||||
|
||||
# These functions are completely new here. If the library already has them
|
||||
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
||||
if hasattr(np.linalg, "vector_norm"):
|
||||
vector_norm = np.linalg.vector_norm
|
||||
else:
|
||||
vector_norm = get_xp(np)(_linalg.vector_norm)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LinAlgError",
|
||||
"cond",
|
||||
"det",
|
||||
"eig",
|
||||
"eigvals",
|
||||
"eigvalsh",
|
||||
"inv",
|
||||
"lstsq",
|
||||
"matrix_power",
|
||||
"multi_dot",
|
||||
"norm",
|
||||
"tensorinv",
|
||||
"tensorsolve",
|
||||
]
|
||||
__all__ += _linalg.__all__
|
||||
__all__ += ["solve", "vector_norm"]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
0
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/py.typed
vendored
Normal file
0
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/py.typed
vendored
Normal file
22
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/__init__.py
vendored
Normal file
22
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/__init__.py
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
from torch import * # noqa: F403
|
||||
|
||||
# Several names are not included in the above import *
|
||||
import torch
|
||||
for n in dir(torch):
|
||||
if (n.startswith('_')
|
||||
or n.endswith('_')
|
||||
or 'cuda' in n
|
||||
or 'cpu' in n
|
||||
or 'backward' in n):
|
||||
continue
|
||||
exec(f"{n} = torch.{n}")
|
||||
del n
|
||||
|
||||
# These imports may overwrite names from the import * above.
|
||||
from ._aliases import * # noqa: F403
|
||||
|
||||
# See the comment in the numpy __init__.py
|
||||
__import__(__package__ + '.linalg')
|
||||
__import__(__package__ + '.fft')
|
||||
|
||||
__array_api_version__ = '2024.12'
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
855
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/_aliases.py
vendored
Normal file
855
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/_aliases.py
vendored
Normal file
@@ -0,0 +1,855 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import reduce as _reduce, wraps as _wraps
|
||||
from builtins import all as _builtin_all, any as _builtin_any
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union, Literal
|
||||
|
||||
import torch
|
||||
|
||||
from .._internal import get_xp
|
||||
from ..common import _aliases
|
||||
from ..common._typing import NestedSequence, SupportsBufferProtocol
|
||||
from ._info import __array_namespace_info__
|
||||
from ._typing import Array, Device, DType
|
||||
|
||||
_int_dtypes = {
|
||||
torch.uint8,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
}
|
||||
try:
|
||||
# torch >=2.3
|
||||
_int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
_array_api_dtypes = {
|
||||
torch.bool,
|
||||
*_int_dtypes,
|
||||
torch.float32,
|
||||
torch.float64,
|
||||
torch.complex64,
|
||||
torch.complex128,
|
||||
}
|
||||
|
||||
_promotion_table = {
|
||||
# ints
|
||||
(torch.int8, torch.int16): torch.int16,
|
||||
(torch.int8, torch.int32): torch.int32,
|
||||
(torch.int8, torch.int64): torch.int64,
|
||||
(torch.int16, torch.int32): torch.int32,
|
||||
(torch.int16, torch.int64): torch.int64,
|
||||
(torch.int32, torch.int64): torch.int64,
|
||||
# ints and uints (mixed sign)
|
||||
(torch.uint8, torch.int8): torch.int16,
|
||||
(torch.uint8, torch.int16): torch.int16,
|
||||
(torch.uint8, torch.int32): torch.int32,
|
||||
(torch.uint8, torch.int64): torch.int64,
|
||||
# floats
|
||||
(torch.float32, torch.float64): torch.float64,
|
||||
# complexes
|
||||
(torch.complex64, torch.complex128): torch.complex128,
|
||||
# Mixed float and complex
|
||||
(torch.float32, torch.complex64): torch.complex64,
|
||||
(torch.float32, torch.complex128): torch.complex128,
|
||||
(torch.float64, torch.complex64): torch.complex128,
|
||||
(torch.float64, torch.complex128): torch.complex128,
|
||||
}
|
||||
|
||||
_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
|
||||
_promotion_table.update({(a, a): a for a in _array_api_dtypes})
|
||||
|
||||
|
||||
def _two_arg(f):
|
||||
@_wraps(f)
|
||||
def _f(x1, x2, /, **kwargs):
|
||||
x1, x2 = _fix_promotion(x1, x2)
|
||||
return f(x1, x2, **kwargs)
|
||||
if _f.__doc__ is None:
|
||||
_f.__doc__ = f"""\
|
||||
Array API compatibility wrapper for torch.{f.__name__}.
|
||||
|
||||
See the corresponding PyTorch documentation and/or the array API specification
|
||||
for more details.
|
||||
|
||||
"""
|
||||
return _f
|
||||
|
||||
def _fix_promotion(x1, x2, only_scalar=True):
|
||||
if not isinstance(x1, torch.Tensor) or not isinstance(x2, torch.Tensor):
|
||||
return x1, x2
|
||||
if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
|
||||
return x1, x2
|
||||
# If an argument is 0-D pytorch downcasts the other argument
|
||||
if not only_scalar or x1.shape == ():
|
||||
dtype = result_type(x1, x2)
|
||||
x2 = x2.to(dtype)
|
||||
if not only_scalar or x2.shape == ():
|
||||
dtype = result_type(x1, x2)
|
||||
x1 = x1.to(dtype)
|
||||
return x1, x2
|
||||
|
||||
|
||||
_py_scalars = (bool, int, float, complex)
|
||||
|
||||
|
||||
def result_type(
|
||||
*arrays_and_dtypes: Array | DType | bool | int | float | complex
|
||||
) -> DType:
|
||||
num = len(arrays_and_dtypes)
|
||||
|
||||
if num == 0:
|
||||
raise ValueError("At least one array or dtype must be provided")
|
||||
|
||||
elif num == 1:
|
||||
x = arrays_and_dtypes[0]
|
||||
if isinstance(x, torch.dtype):
|
||||
return x
|
||||
return x.dtype
|
||||
|
||||
if num == 2:
|
||||
x, y = arrays_and_dtypes
|
||||
return _result_type(x, y)
|
||||
|
||||
else:
|
||||
# sort scalars so that they are treated last
|
||||
scalars, others = [], []
|
||||
for x in arrays_and_dtypes:
|
||||
if isinstance(x, _py_scalars):
|
||||
scalars.append(x)
|
||||
else:
|
||||
others.append(x)
|
||||
if not others:
|
||||
raise ValueError("At least one array or dtype must be provided")
|
||||
|
||||
# combine left-to-right
|
||||
return _reduce(_result_type, others + scalars)
|
||||
|
||||
|
||||
def _result_type(
|
||||
x: Array | DType | bool | int | float | complex,
|
||||
y: Array | DType | bool | int | float | complex,
|
||||
) -> DType:
|
||||
if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
|
||||
xdt = x if isinstance(x, torch.dtype) else x.dtype
|
||||
ydt = y if isinstance(y, torch.dtype) else y.dtype
|
||||
|
||||
try:
|
||||
return _promotion_table[xdt, ydt]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# This doesn't result_type(dtype, dtype) for non-array API dtypes
|
||||
# because torch.result_type only accepts tensors. This does however, allow
|
||||
# cross-kind promotion.
|
||||
x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x
|
||||
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
|
||||
return torch.result_type(x, y)
|
||||
|
||||
|
||||
def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
|
||||
if not isinstance(from_, torch.dtype):
|
||||
from_ = from_.dtype
|
||||
return torch.can_cast(from_, to)
|
||||
|
||||
# Basic renames
|
||||
bitwise_invert = torch.bitwise_not
|
||||
newaxis = None
|
||||
# torch.conj sets the conjugation bit, which breaks conversion to other
|
||||
# libraries. See https://github.com/data-apis/array-api-compat/issues/173
|
||||
conj = torch.conj_physical
|
||||
|
||||
# Two-arg elementwise functions
|
||||
# These require a wrapper to do the correct type promotion on 0-D tensors
|
||||
add = _two_arg(torch.add)
|
||||
atan2 = _two_arg(torch.atan2)
|
||||
bitwise_and = _two_arg(torch.bitwise_and)
|
||||
bitwise_left_shift = _two_arg(torch.bitwise_left_shift)
|
||||
bitwise_or = _two_arg(torch.bitwise_or)
|
||||
bitwise_right_shift = _two_arg(torch.bitwise_right_shift)
|
||||
bitwise_xor = _two_arg(torch.bitwise_xor)
|
||||
copysign = _two_arg(torch.copysign)
|
||||
divide = _two_arg(torch.divide)
|
||||
# Also a rename. torch.equal does not broadcast
|
||||
equal = _two_arg(torch.eq)
|
||||
floor_divide = _two_arg(torch.floor_divide)
|
||||
greater = _two_arg(torch.greater)
|
||||
greater_equal = _two_arg(torch.greater_equal)
|
||||
hypot = _two_arg(torch.hypot)
|
||||
less = _two_arg(torch.less)
|
||||
less_equal = _two_arg(torch.less_equal)
|
||||
logaddexp = _two_arg(torch.logaddexp)
|
||||
# logical functions are not included here because they only accept bool in the
|
||||
# spec, so type promotion is irrelevant.
|
||||
maximum = _two_arg(torch.maximum)
|
||||
minimum = _two_arg(torch.minimum)
|
||||
multiply = _two_arg(torch.multiply)
|
||||
not_equal = _two_arg(torch.not_equal)
|
||||
pow = _two_arg(torch.pow)
|
||||
remainder = _two_arg(torch.remainder)
|
||||
subtract = _two_arg(torch.subtract)
|
||||
|
||||
|
||||
def asarray(
|
||||
obj: (
|
||||
Array
|
||||
| bool | int | float | complex
|
||||
| NestedSequence[bool | int | float | complex]
|
||||
| SupportsBufferProtocol
|
||||
),
|
||||
/,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
copy: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Array:
|
||||
# torch.asarray does not respect input->output device propagation
|
||||
# https://github.com/pytorch/pytorch/issues/150199
|
||||
if device is None and isinstance(obj, torch.Tensor):
|
||||
device = obj.device
|
||||
return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
|
||||
|
||||
|
||||
# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
|
||||
# of 'axis'.
|
||||
|
||||
# torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745
|
||||
def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
|
||||
# https://github.com/pytorch/pytorch/issues/29137
|
||||
if axis == ():
|
||||
return torch.clone(x)
|
||||
return torch.amax(x, axis, keepdims=keepdims)
|
||||
|
||||
def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
|
||||
# https://github.com/pytorch/pytorch/issues/29137
|
||||
if axis == ():
|
||||
return torch.clone(x)
|
||||
return torch.amin(x, axis, keepdims=keepdims)
|
||||
|
||||
clip = get_xp(torch)(_aliases.clip)
|
||||
unstack = get_xp(torch)(_aliases.unstack)
|
||||
cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
|
||||
cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
|
||||
finfo = get_xp(torch)(_aliases.finfo)
|
||||
iinfo = get_xp(torch)(_aliases.iinfo)
|
||||
|
||||
|
||||
# torch.sort also returns a tuple
|
||||
# https://github.com/pytorch/pytorch/issues/70921
|
||||
def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array:
|
||||
return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values
|
||||
|
||||
def _normalize_axes(axis, ndim):
|
||||
axes = []
|
||||
if ndim == 0 and axis:
|
||||
# Better error message in this case
|
||||
raise IndexError(f"Dimension out of range: {axis[0]}")
|
||||
lower, upper = -ndim, ndim - 1
|
||||
for a in axis:
|
||||
if a < lower or a > upper:
|
||||
# Match torch error message (e.g., from sum())
|
||||
raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}")
|
||||
if a < 0:
|
||||
a = a + ndim
|
||||
if a in axes:
|
||||
# Use IndexError instead of RuntimeError, and "axis" instead of "dim"
|
||||
raise IndexError(f"Axis {a} appears multiple times in the list of axes")
|
||||
axes.append(a)
|
||||
return sorted(axes)
|
||||
|
||||
def _axis_none_keepdims(x, ndim, keepdims):
|
||||
# Apply keepdims when axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
# Note that this is only valid for the axis=None case.
|
||||
if keepdims:
|
||||
for i in range(ndim):
|
||||
x = torch.unsqueeze(x, 0)
|
||||
return x
|
||||
|
||||
def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
|
||||
# Some reductions don't support multiple axes
|
||||
# (https://github.com/pytorch/pytorch/issues/56586).
|
||||
axes = _normalize_axes(axis, x.ndim)
|
||||
for a in reversed(axes):
|
||||
x = torch.movedim(x, a, -1)
|
||||
x = torch.flatten(x, -len(axes))
|
||||
|
||||
out = f(x, -1, **kwargs)
|
||||
|
||||
if keepdims:
|
||||
for a in axes:
|
||||
out = torch.unsqueeze(out, a)
|
||||
return out
|
||||
|
||||
|
||||
def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
|
||||
"""
|
||||
Implements `sum(..., axis=())` and `prod(..., axis=())`.
|
||||
|
||||
Works around https://github.com/pytorch/pytorch/issues/29137
|
||||
"""
|
||||
if dtype is not None:
|
||||
return x.clone() if dtype == x.dtype else x.to(dtype)
|
||||
|
||||
# We can't upcast uint8 according to the spec because there is no
|
||||
# torch.uint64, so at least upcast to int64 which is what prod does
|
||||
# when axis=None.
|
||||
if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
|
||||
return x.to(torch.int64)
|
||||
|
||||
return x.clone()
|
||||
|
||||
|
||||
def prod(x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype: Optional[DType] = None,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> Array:
|
||||
|
||||
if axis == ():
|
||||
return _sum_prod_no_axis(x, dtype)
|
||||
# torch.prod doesn't support multiple axes
|
||||
# (https://github.com/pytorch/pytorch/issues/56586).
|
||||
if isinstance(axis, tuple):
|
||||
return _reduce_multiple_axes(torch.prod, x, axis, keepdims=keepdims, dtype=dtype, **kwargs)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.prod(x, dtype=dtype, **kwargs)
|
||||
res = _axis_none_keepdims(res, x.ndim, keepdims)
|
||||
return res
|
||||
|
||||
return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
|
||||
|
||||
|
||||
def sum(x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype: Optional[DType] = None,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> Array:
|
||||
|
||||
if axis == ():
|
||||
return _sum_prod_no_axis(x, dtype)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.sum(x, dtype=dtype, **kwargs)
|
||||
res = _axis_none_keepdims(res, x.ndim, keepdims)
|
||||
return res
|
||||
|
||||
return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
|
||||
|
||||
def any(x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> Array:
|
||||
|
||||
if axis == ():
|
||||
return x.to(torch.bool)
|
||||
# torch.any doesn't support multiple axes
|
||||
# (https://github.com/pytorch/pytorch/issues/56586).
|
||||
if isinstance(axis, tuple):
|
||||
res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs)
|
||||
return res.to(torch.bool)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.any(x, **kwargs)
|
||||
res = _axis_none_keepdims(res, x.ndim, keepdims)
|
||||
return res.to(torch.bool)
|
||||
|
||||
# torch.any doesn't return bool for uint8
|
||||
return torch.any(x, axis, keepdims=keepdims).to(torch.bool)
|
||||
|
||||
def all(x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> Array:
|
||||
|
||||
if axis == ():
|
||||
return x.to(torch.bool)
|
||||
# torch.all doesn't support multiple axes
|
||||
# (https://github.com/pytorch/pytorch/issues/56586).
|
||||
if isinstance(axis, tuple):
|
||||
res = _reduce_multiple_axes(torch.all, x, axis, keepdims=keepdims, **kwargs)
|
||||
return res.to(torch.bool)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.all(x, **kwargs)
|
||||
res = _axis_none_keepdims(res, x.ndim, keepdims)
|
||||
return res.to(torch.bool)
|
||||
|
||||
# torch.all doesn't return bool for uint8
|
||||
return torch.all(x, axis, keepdims=keepdims).to(torch.bool)
|
||||
|
||||
def mean(x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> Array:
|
||||
# https://github.com/pytorch/pytorch/issues/29137
|
||||
if axis == ():
|
||||
return torch.clone(x)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.mean(x, **kwargs)
|
||||
res = _axis_none_keepdims(res, x.ndim, keepdims)
|
||||
return res
|
||||
return torch.mean(x, axis, keepdims=keepdims, **kwargs)
|
||||
|
||||
def std(x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
correction: Union[int, float] = 0.0,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> Array:
|
||||
# Note, float correction is not supported
|
||||
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
|
||||
# implement it here for now.
|
||||
|
||||
if isinstance(correction, float):
|
||||
_correction = int(correction)
|
||||
if correction != _correction:
|
||||
raise NotImplementedError("float correction in torch std() is not yet supported")
|
||||
else:
|
||||
_correction = correction
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/29137
|
||||
if axis == ():
|
||||
return torch.zeros_like(x)
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs)
|
||||
res = _axis_none_keepdims(res, x.ndim, keepdims)
|
||||
return res
|
||||
return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs)
|
||||
|
||||
def var(x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
correction: Union[int, float] = 0.0,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> Array:
|
||||
# Note, float correction is not supported
|
||||
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
|
||||
# implement it here for now.
|
||||
|
||||
# if isinstance(correction, float):
|
||||
# correction = int(correction)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/29137
|
||||
if axis == ():
|
||||
return torch.zeros_like(x)
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.var(x, tuple(range(x.ndim)), correction=correction, **kwargs)
|
||||
res = _axis_none_keepdims(res, x.ndim, keepdims)
|
||||
return res
|
||||
return torch.var(x, axis, correction=correction, keepdims=keepdims, **kwargs)
|
||||
|
||||
# torch.concat doesn't support dim=None
|
||||
# https://github.com/pytorch/pytorch/issues/70925
|
||||
def concat(arrays: Union[Tuple[Array, ...], List[Array]],
|
||||
/,
|
||||
*,
|
||||
axis: Optional[int] = 0,
|
||||
**kwargs) -> Array:
|
||||
if axis is None:
|
||||
arrays = tuple(ar.flatten() for ar in arrays)
|
||||
axis = 0
|
||||
return torch.concat(arrays, axis, **kwargs)
|
||||
|
||||
# torch.squeeze only accepts int dim and doesn't require it
|
||||
# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
|
||||
# added at https://github.com/pytorch/pytorch/pull/89017.
|
||||
def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
for a in axis:
|
||||
if x.shape[a] != 1:
|
||||
raise ValueError("squeezed dimensions must be equal to 1")
|
||||
axes = _normalize_axes(axis, x.ndim)
|
||||
# Remove this once pytorch 1.14 is released with the above PR #89017.
|
||||
sequence = [a - i for i, a in enumerate(axes)]
|
||||
for a in sequence:
|
||||
x = torch.squeeze(x, a)
|
||||
return x
|
||||
|
||||
# torch.broadcast_to uses size instead of shape
|
||||
def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array:
|
||||
return torch.broadcast_to(x, shape, **kwargs)
|
||||
|
||||
# torch.permute uses dims instead of axes
|
||||
def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
|
||||
return torch.permute(x, axes)
|
||||
|
||||
# The axis parameter doesn't work for flip() and roll()
|
||||
# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
|
||||
# accept axis=None
|
||||
def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array:
|
||||
if axis is None:
|
||||
axis = tuple(range(x.ndim))
|
||||
# torch.flip doesn't accept dim as an int but the method does
|
||||
# https://github.com/pytorch/pytorch/issues/18095
|
||||
return x.flip(axis, **kwargs)
|
||||
|
||||
def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array:
|
||||
return torch.roll(x, shift, axis, **kwargs)
|
||||
|
||||
def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]:
|
||||
if x.ndim == 0:
|
||||
raise ValueError("nonzero() does not support zero-dimensional arrays")
|
||||
return torch.nonzero(x, as_tuple=True, **kwargs)
|
||||
|
||||
|
||||
# torch uses `dim` instead of `axis`
|
||||
def diff(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: int = -1,
|
||||
n: int = 1,
|
||||
prepend: Optional[Array] = None,
|
||||
append: Optional[Array] = None,
|
||||
) -> Array:
|
||||
return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
|
||||
|
||||
|
||||
# torch uses `dim` instead of `axis`, does not have keepdims
|
||||
def count_nonzero(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
keepdims: bool = False,
|
||||
) -> Array:
|
||||
result = torch.count_nonzero(x, dim=axis)
|
||||
if keepdims:
|
||||
if isinstance(axis, int):
|
||||
return result.unsqueeze(axis)
|
||||
elif isinstance(axis, tuple):
|
||||
n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis]
|
||||
sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)]
|
||||
return torch.reshape(result, sh)
|
||||
return _axis_none_keepdims(result, x.ndim, keepdims)
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
# "repeat" is torch.repeat_interleave; also the dim argument
|
||||
def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
|
||||
return torch.repeat_interleave(x, repeats, axis)
|
||||
|
||||
|
||||
def where(
|
||||
condition: Array,
|
||||
x1: Array | bool | int | float | complex,
|
||||
x2: Array | bool | int | float | complex,
|
||||
/,
|
||||
) -> Array:
|
||||
x1, x2 = _fix_promotion(x1, x2)
|
||||
return torch.where(condition, x1, x2)
|
||||
|
||||
|
||||
# torch.reshape doesn't have the copy keyword
|
||||
def reshape(x: Array,
|
||||
/,
|
||||
shape: Tuple[int, ...],
|
||||
*,
|
||||
copy: Optional[bool] = None,
|
||||
**kwargs) -> Array:
|
||||
if copy is not None:
|
||||
raise NotImplementedError("torch.reshape doesn't yet support the copy keyword")
|
||||
return torch.reshape(x, shape, **kwargs)
|
||||
|
||||
# torch.arange doesn't support returning empty arrays
|
||||
# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
|
||||
# keyword argument combinations
|
||||
# (https://github.com/pytorch/pytorch/issues/70914)
|
||||
def arange(start: Union[int, float],
|
||||
/,
|
||||
stop: Optional[Union[int, float]] = None,
|
||||
step: Union[int, float] = 1,
|
||||
*,
|
||||
dtype: Optional[DType] = None,
|
||||
device: Optional[Device] = None,
|
||||
**kwargs) -> Array:
|
||||
if stop is None:
|
||||
start, stop = 0, start
|
||||
if step > 0 and stop <= start or step < 0 and stop >= start:
|
||||
if dtype is None:
|
||||
if _builtin_all(isinstance(i, int) for i in [start, stop, step]):
|
||||
dtype = torch.int64
|
||||
else:
|
||||
dtype = torch.float32
|
||||
return torch.empty(0, dtype=dtype, device=device, **kwargs)
|
||||
return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
# torch.eye does not accept None as a default for the second argument and
|
||||
# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
|
||||
def eye(n_rows: int,
|
||||
n_cols: Optional[int] = None,
|
||||
/,
|
||||
*,
|
||||
k: int = 0,
|
||||
dtype: Optional[DType] = None,
|
||||
device: Optional[Device] = None,
|
||||
**kwargs) -> Array:
|
||||
if n_cols is None:
|
||||
n_cols = n_rows
|
||||
z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs)
|
||||
if abs(k) <= n_rows + n_cols:
|
||||
z.diagonal(k).fill_(1)
|
||||
return z
|
||||
|
||||
# torch.linspace doesn't have the endpoint parameter
|
||||
def linspace(start: Union[int, float],
|
||||
stop: Union[int, float],
|
||||
/,
|
||||
num: int,
|
||||
*,
|
||||
dtype: Optional[DType] = None,
|
||||
device: Optional[Device] = None,
|
||||
endpoint: bool = True,
|
||||
**kwargs) -> Array:
|
||||
if not endpoint:
|
||||
return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1]
|
||||
return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
# torch.full does not accept an int size
|
||||
# https://github.com/pytorch/pytorch/issues/70906
|
||||
def full(shape: Union[int, Tuple[int, ...]],
|
||||
fill_value: bool | int | float | complex,
|
||||
*,
|
||||
dtype: Optional[DType] = None,
|
||||
device: Optional[Device] = None,
|
||||
**kwargs) -> Array:
|
||||
if isinstance(shape, int):
|
||||
shape = (shape,)
|
||||
|
||||
return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
# ones, zeros, and empty do not accept shape as a keyword argument
|
||||
def ones(shape: Union[int, Tuple[int, ...]],
|
||||
*,
|
||||
dtype: Optional[DType] = None,
|
||||
device: Optional[Device] = None,
|
||||
**kwargs) -> Array:
|
||||
return torch.ones(shape, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
def zeros(shape: Union[int, Tuple[int, ...]],
|
||||
*,
|
||||
dtype: Optional[DType] = None,
|
||||
device: Optional[Device] = None,
|
||||
**kwargs) -> Array:
|
||||
return torch.zeros(shape, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
def empty(shape: Union[int, Tuple[int, ...]],
|
||||
*,
|
||||
dtype: Optional[DType] = None,
|
||||
device: Optional[Device] = None,
|
||||
**kwargs) -> Array:
|
||||
return torch.empty(shape, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
# tril and triu do not call the keyword argument k
|
||||
|
||||
def tril(x: Array, /, *, k: int = 0) -> Array:
|
||||
return torch.tril(x, k)
|
||||
|
||||
def triu(x: Array, /, *, k: int = 0) -> Array:
|
||||
return torch.triu(x, k)
|
||||
|
||||
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
|
||||
def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
|
||||
return torch.unsqueeze(x, axis)
|
||||
|
||||
|
||||
def astype(
|
||||
x: Array,
|
||||
dtype: DType,
|
||||
/,
|
||||
*,
|
||||
copy: bool = True,
|
||||
device: Optional[Device] = None,
|
||||
) -> Array:
|
||||
if device is not None:
|
||||
return x.to(device, dtype=dtype, copy=copy)
|
||||
return x.to(dtype=dtype, copy=copy)
|
||||
|
||||
|
||||
def broadcast_arrays(*arrays: Array) -> List[Array]:
|
||||
shape = torch.broadcast_shapes(*[a.shape for a in arrays])
|
||||
return [torch.broadcast_to(a, shape) for a in arrays]
|
||||
|
||||
# Note that these named tuples aren't actually part of the standard namespace,
|
||||
# but I don't see any issue with exporting the names here regardless.
|
||||
from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
|
||||
UniqueInverseResult)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/70920
|
||||
def unique_all(x: Array) -> UniqueAllResult:
|
||||
# torch.unique doesn't support returning indices.
|
||||
# https://github.com/pytorch/pytorch/issues/36748. The workaround
|
||||
# suggested in that issue doesn't actually function correctly (it relies
|
||||
# on non-deterministic behavior of scatter()).
|
||||
raise NotImplementedError("unique_all() not yet implemented for pytorch (see https://github.com/pytorch/pytorch/issues/36748)")
|
||||
|
||||
# values, inverse_indices, counts = torch.unique(x, return_counts=True, return_inverse=True)
|
||||
# # torch.unique incorrectly gives a 0 count for nan values.
|
||||
# # https://github.com/pytorch/pytorch/issues/94106
|
||||
# counts[torch.isnan(values)] = 1
|
||||
# return UniqueAllResult(values, indices, inverse_indices, counts)
|
||||
|
||||
def unique_counts(x: Array) -> UniqueCountsResult:
|
||||
values, counts = torch.unique(x, return_counts=True)
|
||||
|
||||
# torch.unique incorrectly gives a 0 count for nan values.
|
||||
# https://github.com/pytorch/pytorch/issues/94106
|
||||
counts[torch.isnan(values)] = 1
|
||||
return UniqueCountsResult(values, counts)
|
||||
|
||||
def unique_inverse(x: Array) -> UniqueInverseResult:
|
||||
values, inverse = torch.unique(x, return_inverse=True)
|
||||
return UniqueInverseResult(values, inverse)
|
||||
|
||||
def unique_values(x: Array) -> Array:
|
||||
return torch.unique(x)
|
||||
|
||||
def matmul(x1: Array, x2: Array, /, **kwargs) -> Array:
|
||||
# torch.matmul doesn't type promote (but differently from _fix_promotion)
|
||||
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
|
||||
return torch.matmul(x1, x2, **kwargs)
|
||||
|
||||
matrix_transpose = get_xp(torch)(_aliases.matrix_transpose)
|
||||
_vecdot = get_xp(torch)(_aliases.vecdot)
|
||||
|
||||
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
|
||||
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
|
||||
return _vecdot(x1, x2, axis=axis)
|
||||
|
||||
# torch.tensordot uses dims instead of axes
|
||||
def tensordot(
|
||||
x1: Array,
|
||||
x2: Array,
|
||||
/,
|
||||
*,
|
||||
axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
|
||||
**kwargs,
|
||||
) -> Array:
|
||||
# Note: torch.tensordot fails with integer dtypes when there is only 1
|
||||
# element in the axis (https://github.com/pytorch/pytorch/issues/84530).
|
||||
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
|
||||
return torch.tensordot(x1, x2, dims=axes, **kwargs)
|
||||
|
||||
|
||||
def isdtype(
|
||||
dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]],
|
||||
*, _tuple=True, # Disallow nested tuples
|
||||
) -> bool:
|
||||
"""
|
||||
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
|
||||
|
||||
Note that outside of this function, this compat library does not yet fully
|
||||
support complex numbers.
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
|
||||
for more details
|
||||
"""
|
||||
if isinstance(kind, tuple) and _tuple:
|
||||
return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
|
||||
elif isinstance(kind, str):
|
||||
if kind == 'bool':
|
||||
return dtype == torch.bool
|
||||
elif kind == 'signed integer':
|
||||
return dtype in _int_dtypes and dtype.is_signed
|
||||
elif kind == 'unsigned integer':
|
||||
return dtype in _int_dtypes and not dtype.is_signed
|
||||
elif kind == 'integral':
|
||||
return dtype in _int_dtypes
|
||||
elif kind == 'real floating':
|
||||
return dtype.is_floating_point
|
||||
elif kind == 'complex floating':
|
||||
return dtype.is_complex
|
||||
elif kind == 'numeric':
|
||||
return isdtype(dtype, ('integral', 'real floating', 'complex floating'))
|
||||
else:
|
||||
raise ValueError(f"Unrecognized data type kind: {kind!r}")
|
||||
else:
|
||||
return dtype == kind
|
||||
|
||||
def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array:
|
||||
if axis is None:
|
||||
if x.ndim != 1:
|
||||
raise ValueError("axis must be specified when ndim > 1")
|
||||
axis = 0
|
||||
return torch.index_select(x, axis, indices, **kwargs)
|
||||
|
||||
|
||||
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
|
||||
return torch.take_along_dim(x, indices, dim=axis)
|
||||
|
||||
|
||||
def sign(x: Array, /) -> Array:
|
||||
# torch sign() does not support complex numbers and does not propagate
|
||||
# nans. See https://github.com/data-apis/array-api-compat/issues/136
|
||||
if x.dtype.is_complex:
|
||||
out = x/torch.abs(x)
|
||||
# sign(0) = 0 but the above formula would give nan
|
||||
out[x == 0+0j] = 0+0j
|
||||
return out
|
||||
else:
|
||||
out = torch.sign(x)
|
||||
if x.dtype.is_floating_point:
|
||||
out[torch.isnan(x)] = torch.nan
|
||||
return out
|
||||
|
||||
|
||||
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]:
|
||||
# enforce the default of 'xy'
|
||||
# TODO: is the return type a list or a tuple
|
||||
return list(torch.meshgrid(*arrays, indexing='xy'))
|
||||
|
||||
|
||||
__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
|
||||
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
|
||||
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
|
||||
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
|
||||
'diff', 'divide',
|
||||
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
|
||||
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
|
||||
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
|
||||
'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
|
||||
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
|
||||
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
|
||||
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
|
||||
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
|
||||
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
|
||||
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
|
||||
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
|
||||
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid']
|
||||
|
||||
_all_ignore = ['torch', 'get_xp']
|
||||
369
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/_info.py
vendored
Normal file
369
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/_info.py
vendored
Normal file
@@ -0,0 +1,369 @@
|
||||
"""
|
||||
Array API Inspection namespace
|
||||
|
||||
This is the namespace for inspection functions as defined by the array API
|
||||
standard. See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html for
|
||||
more details.
|
||||
|
||||
"""
|
||||
import torch
|
||||
|
||||
from functools import cache
|
||||
|
||||
class __array_namespace_info__:
|
||||
"""
|
||||
Get the array API inspection namespace for PyTorch.
|
||||
|
||||
The array API inspection namespace defines the following functions:
|
||||
|
||||
- capabilities()
|
||||
- default_device()
|
||||
- default_dtypes()
|
||||
- dtypes()
|
||||
- devices()
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html
|
||||
for more details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
info : ModuleType
|
||||
The array API inspection namespace for PyTorch.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': numpy.float64,
|
||||
'complex floating': numpy.complex128,
|
||||
'integral': numpy.int64,
|
||||
'indexing': numpy.int64}
|
||||
|
||||
"""
|
||||
|
||||
__module__ = 'torch'
|
||||
|
||||
def capabilities(self):
|
||||
"""
|
||||
Return a dictionary of array API library capabilities.
|
||||
|
||||
The resulting dictionary has the following keys:
|
||||
|
||||
- **"boolean indexing"**: boolean indicating whether an array library
|
||||
supports boolean indexing. Always ``True`` for PyTorch.
|
||||
|
||||
- **"data-dependent shapes"**: boolean indicating whether an array
|
||||
library supports data-dependent output shapes. Always ``True`` for
|
||||
PyTorch.
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
|
||||
for more details.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
capabilities : dict
|
||||
A dictionary of array API library capabilities.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.capabilities()
|
||||
{'boolean indexing': True,
|
||||
'data-dependent shapes': True,
|
||||
'max dimensions': 64}
|
||||
|
||||
"""
|
||||
return {
|
||||
"boolean indexing": True,
|
||||
"data-dependent shapes": True,
|
||||
"max dimensions": 64,
|
||||
}
|
||||
|
||||
def default_device(self):
|
||||
"""
|
||||
The default device used for new PyTorch arrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
device : Device
|
||||
The default device used for new PyTorch arrays.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_device()
|
||||
device(type='cpu')
|
||||
|
||||
Notes
|
||||
-----
|
||||
This method returns the static default device when PyTorch is initialized.
|
||||
However, the *current* device used by creation functions (``empty`` etc.)
|
||||
can be changed at runtime.
|
||||
|
||||
See Also
|
||||
--------
|
||||
https://github.com/data-apis/array-api/issues/835
|
||||
"""
|
||||
return torch.device("cpu")
|
||||
|
||||
def default_dtypes(self, *, device=None):
|
||||
"""
|
||||
The default data types used for new PyTorch arrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : Device, optional
|
||||
The device to get the default data types for.
|
||||
Unused for PyTorch, as all devices use the same default dtypes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary describing the default data types used for new PyTorch
|
||||
arrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': torch.float32,
|
||||
'complex floating': torch.complex64,
|
||||
'integral': torch.int64,
|
||||
'indexing': torch.int64}
|
||||
|
||||
"""
|
||||
# Note: if the default is set to float64, the devices like MPS that
|
||||
# don't support float64 will error. We still return the default_dtype
|
||||
# value here because this error doesn't represent a different default
|
||||
# per-device.
|
||||
default_floating = torch.get_default_dtype()
|
||||
default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128
|
||||
default_integral = torch.int64
|
||||
return {
|
||||
"real floating": default_floating,
|
||||
"complex floating": default_complex,
|
||||
"integral": default_integral,
|
||||
"indexing": default_integral,
|
||||
}
|
||||
|
||||
|
||||
def _dtypes(self, kind):
|
||||
bool = torch.bool
|
||||
int8 = torch.int8
|
||||
int16 = torch.int16
|
||||
int32 = torch.int32
|
||||
int64 = torch.int64
|
||||
uint8 = torch.uint8
|
||||
# uint16, uint32, and uint64 are present in newer versions of pytorch,
|
||||
# but they aren't generally supported by the array API functions, so
|
||||
# we omit them from this function.
|
||||
float32 = torch.float32
|
||||
float64 = torch.float64
|
||||
complex64 = torch.complex64
|
||||
complex128 = torch.complex128
|
||||
|
||||
if kind is None:
|
||||
return {
|
||||
"bool": bool,
|
||||
"int8": int8,
|
||||
"int16": int16,
|
||||
"int32": int32,
|
||||
"int64": int64,
|
||||
"uint8": uint8,
|
||||
"float32": float32,
|
||||
"float64": float64,
|
||||
"complex64": complex64,
|
||||
"complex128": complex128,
|
||||
}
|
||||
if kind == "bool":
|
||||
return {"bool": bool}
|
||||
if kind == "signed integer":
|
||||
return {
|
||||
"int8": int8,
|
||||
"int16": int16,
|
||||
"int32": int32,
|
||||
"int64": int64,
|
||||
}
|
||||
if kind == "unsigned integer":
|
||||
return {
|
||||
"uint8": uint8,
|
||||
}
|
||||
if kind == "integral":
|
||||
return {
|
||||
"int8": int8,
|
||||
"int16": int16,
|
||||
"int32": int32,
|
||||
"int64": int64,
|
||||
"uint8": uint8,
|
||||
}
|
||||
if kind == "real floating":
|
||||
return {
|
||||
"float32": float32,
|
||||
"float64": float64,
|
||||
}
|
||||
if kind == "complex floating":
|
||||
return {
|
||||
"complex64": complex64,
|
||||
"complex128": complex128,
|
||||
}
|
||||
if kind == "numeric":
|
||||
return {
|
||||
"int8": int8,
|
||||
"int16": int16,
|
||||
"int32": int32,
|
||||
"int64": int64,
|
||||
"uint8": uint8,
|
||||
"float32": float32,
|
||||
"float64": float64,
|
||||
"complex64": complex64,
|
||||
"complex128": complex128,
|
||||
}
|
||||
if isinstance(kind, tuple):
|
||||
res = {}
|
||||
for k in kind:
|
||||
res.update(self.dtypes(kind=k))
|
||||
return res
|
||||
raise ValueError(f"unsupported kind: {kind!r}")
|
||||
|
||||
@cache
|
||||
def dtypes(self, *, device=None, kind=None):
|
||||
"""
|
||||
The array API data types supported by PyTorch.
|
||||
|
||||
Note that this function only returns data types that are defined by
|
||||
the array API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : Device, optional
|
||||
The device to get the data types for.
|
||||
Unused for PyTorch, as all devices use the same dtypes.
|
||||
kind : str or tuple of str, optional
|
||||
The kind of data types to return. If ``None``, all data types are
|
||||
returned. If a string, only data types of that kind are returned.
|
||||
If a tuple, a dictionary containing the union of the given kinds
|
||||
is returned. The following kinds are supported:
|
||||
|
||||
- ``'bool'``: boolean data types (i.e., ``bool``).
|
||||
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
|
||||
``int16``, ``int32``, ``int64``).
|
||||
- ``'unsigned integer'``: unsigned integer data types (i.e.,
|
||||
``uint8``, ``uint16``, ``uint32``, ``uint64``).
|
||||
- ``'integral'``: integer data types. Shorthand for ``('signed
|
||||
integer', 'unsigned integer')``.
|
||||
- ``'real floating'``: real-valued floating-point data types
|
||||
(i.e., ``float32``, ``float64``).
|
||||
- ``'complex floating'``: complex floating-point data types (i.e.,
|
||||
``complex64``, ``complex128``).
|
||||
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
|
||||
'real floating', 'complex floating')``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary mapping the names of data types to the corresponding
|
||||
PyTorch data types.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.dtypes(kind='signed integer')
|
||||
{'int8': numpy.int8,
|
||||
'int16': numpy.int16,
|
||||
'int32': numpy.int32,
|
||||
'int64': numpy.int64}
|
||||
|
||||
"""
|
||||
res = self._dtypes(kind)
|
||||
for k, v in res.copy().items():
|
||||
try:
|
||||
torch.empty((0,), dtype=v, device=device)
|
||||
except:
|
||||
del res[k]
|
||||
return res
|
||||
|
||||
@cache
|
||||
def devices(self):
|
||||
"""
|
||||
The devices supported by PyTorch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
devices : list[Device]
|
||||
The devices supported by PyTorch.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.devices()
|
||||
[device(type='cpu'), device(type='mps', index=0), device(type='meta')]
|
||||
|
||||
"""
|
||||
# Torch doesn't have a straightforward way to get the list of all
|
||||
# currently supported devices. To do this, we first parse the error
|
||||
# message of torch.device to get the list of all possible types of
|
||||
# device:
|
||||
try:
|
||||
torch.device('notadevice')
|
||||
raise AssertionError("unreachable") # pragma: nocover
|
||||
except RuntimeError as e:
|
||||
# The error message is something like:
|
||||
# "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"
|
||||
devices_names = e.args[0].split('Expected one of ')[1].split(' device type')[0].split(', ')
|
||||
|
||||
# Next we need to check for different indices for different devices.
|
||||
# device(device_name, index=index) doesn't actually check if the
|
||||
# device name or index is valid. We have to try to create a tensor
|
||||
# with it (which is why this function is cached).
|
||||
devices = []
|
||||
for device_name in devices_names:
|
||||
i = 0
|
||||
while True:
|
||||
try:
|
||||
a = torch.empty((0,), device=torch.device(device_name, index=i))
|
||||
if a.device in devices:
|
||||
break
|
||||
devices.append(a.device)
|
||||
except:
|
||||
break
|
||||
i += 1
|
||||
|
||||
return devices
|
||||
3
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/_typing.py
vendored
Normal file
3
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/_typing.py
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
__all__ = ["Array", "Device", "DType"]
|
||||
|
||||
from torch import device as Device, dtype as DType, Tensor as Array
|
||||
85
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/fft.py
vendored
Normal file
85
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/fft.py
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, Sequence, Literal
|
||||
|
||||
import torch
|
||||
import torch.fft
|
||||
from torch.fft import * # noqa: F403
|
||||
|
||||
from ._typing import Array
|
||||
|
||||
# Several torch fft functions do not map axes to dim
|
||||
|
||||
def fftn(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
s: Sequence[int] = None,
|
||||
axes: Sequence[int] = None,
|
||||
norm: Literal["backward", "ortho", "forward"] = "backward",
|
||||
**kwargs,
|
||||
) -> Array:
|
||||
return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
|
||||
|
||||
def ifftn(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
s: Sequence[int] = None,
|
||||
axes: Sequence[int] = None,
|
||||
norm: Literal["backward", "ortho", "forward"] = "backward",
|
||||
**kwargs,
|
||||
) -> Array:
|
||||
return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
|
||||
|
||||
def rfftn(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
s: Sequence[int] = None,
|
||||
axes: Sequence[int] = None,
|
||||
norm: Literal["backward", "ortho", "forward"] = "backward",
|
||||
**kwargs,
|
||||
) -> Array:
|
||||
return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
|
||||
|
||||
def irfftn(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
s: Sequence[int] = None,
|
||||
axes: Sequence[int] = None,
|
||||
norm: Literal["backward", "ortho", "forward"] = "backward",
|
||||
**kwargs,
|
||||
) -> Array:
|
||||
return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
|
||||
|
||||
def fftshift(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
axes: Union[int, Sequence[int]] = None,
|
||||
**kwargs,
|
||||
) -> Array:
|
||||
return torch.fft.fftshift(x, dim=axes, **kwargs)
|
||||
|
||||
def ifftshift(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
axes: Union[int, Sequence[int]] = None,
|
||||
**kwargs,
|
||||
) -> Array:
|
||||
return torch.fft.ifftshift(x, dim=axes, **kwargs)
|
||||
|
||||
|
||||
__all__ = torch.fft.__all__ + [
|
||||
"fftn",
|
||||
"ifftn",
|
||||
"rfftn",
|
||||
"irfftn",
|
||||
"fftshift",
|
||||
"ifftshift",
|
||||
]
|
||||
|
||||
_all_ignore = ['torch']
|
||||
121
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/linalg.py
vendored
Normal file
121
venv/lib/python3.12/site-packages/sklearn/externals/array_api_compat/torch/linalg.py
vendored
Normal file
@@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from typing import Optional, Union, Tuple
|
||||
|
||||
from torch.linalg import * # noqa: F403
|
||||
|
||||
# torch.linalg doesn't define __all__
|
||||
# from torch.linalg import __all__ as linalg_all
|
||||
from torch import linalg as torch_linalg
|
||||
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
|
||||
|
||||
# outer is implemented in torch but aren't in the linalg namespace
|
||||
from torch import outer
|
||||
from ._aliases import _fix_promotion, sum
|
||||
# These functions are in both the main and linalg namespaces
|
||||
from ._aliases import matmul, matrix_transpose, tensordot
|
||||
from ._typing import Array, DType
|
||||
from ..common._typing import JustInt, JustFloat
|
||||
|
||||
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
|
||||
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
|
||||
|
||||
# torch.cross also does not support broadcasting when it would add new
|
||||
# dimensions https://github.com/pytorch/pytorch/issues/39656
|
||||
def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
|
||||
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
|
||||
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
|
||||
raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
|
||||
if not (x1.shape[axis] == x2.shape[axis] == 3):
|
||||
raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
|
||||
x1, x2 = torch.broadcast_tensors(x1, x2)
|
||||
return torch_linalg.cross(x1, x2, dim=axis)
|
||||
|
||||
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array:
|
||||
from ._aliases import isdtype
|
||||
|
||||
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
|
||||
|
||||
# torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
|
||||
if x1.shape[axis] != x2.shape[axis]:
|
||||
raise ValueError("x1 and x2 must have the same size along the given axis")
|
||||
|
||||
# torch.linalg.vecdot doesn't support integer dtypes
|
||||
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
|
||||
if kwargs:
|
||||
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
|
||||
|
||||
x1_ = torch.moveaxis(x1, axis, -1)
|
||||
x2_ = torch.moveaxis(x2, axis, -1)
|
||||
x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
|
||||
|
||||
res = x1_[..., None, :] @ x2_[..., None]
|
||||
return res[..., 0, 0]
|
||||
return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
|
||||
|
||||
def solve(x1: Array, x2: Array, /, **kwargs) -> Array:
|
||||
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
|
||||
# Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
|
||||
# whenever
|
||||
# 1. x1.ndim - 1 == x2.ndim
|
||||
# 2. x1.shape[:-1] == x2.shape
|
||||
#
|
||||
# See linalg_solve_is_vector_rhs in
|
||||
# aten/src/ATen/native/LinearAlgebraUtils.h and
|
||||
# TORCH_META_FUNC(_linalg_solve_ex) in
|
||||
# aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
|
||||
#
|
||||
# The easiest way to work around this is to prepend a size 1 dimension to
|
||||
# x2, since x2 is already one dimension less than x1.
|
||||
#
|
||||
# See https://github.com/pytorch/pytorch/issues/52915
|
||||
if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
|
||||
x2 = x2[None]
|
||||
return torch.linalg.solve(x1, x2, **kwargs)
|
||||
|
||||
# torch.trace doesn't support the offset argument and doesn't support stacking
|
||||
def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array:
|
||||
# Use our wrapped sum to make sure it does upcasting correctly
|
||||
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
|
||||
|
||||
def vector_norm(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
keepdims: bool = False,
|
||||
# JustFloat stands for inf | -inf, which are not valid for Literal
|
||||
ord: JustInt | JustFloat = 2,
|
||||
**kwargs,
|
||||
) -> Array:
|
||||
# torch.vector_norm incorrectly treats axis=() the same as axis=None
|
||||
if axis == ():
|
||||
out = kwargs.get('out')
|
||||
if out is None:
|
||||
dtype = None
|
||||
if x.dtype == torch.complex64:
|
||||
dtype = torch.float32
|
||||
elif x.dtype == torch.complex128:
|
||||
dtype = torch.float64
|
||||
|
||||
out = torch.zeros_like(x, dtype=dtype)
|
||||
|
||||
# The norm of a single scalar works out to abs(x) in every case except
|
||||
# for ord=0, which is x != 0.
|
||||
if ord == 0:
|
||||
out[:] = (x != 0)
|
||||
else:
|
||||
out[:] = torch.abs(x)
|
||||
return out
|
||||
return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
|
||||
|
||||
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
|
||||
'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
|
||||
|
||||
_all_ignore = ['torch_linalg', 'sum']
|
||||
|
||||
del linalg_all
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
21
venv/lib/python3.12/site-packages/sklearn/externals/array_api_extra/LICENSE
vendored
Normal file
21
venv/lib/python3.12/site-packages/sklearn/externals/array_api_extra/LICENSE
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Consortium for Python Data API Standards
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
1
venv/lib/python3.12/site-packages/sklearn/externals/array_api_extra/README.md
vendored
Normal file
1
venv/lib/python3.12/site-packages/sklearn/externals/array_api_extra/README.md
vendored
Normal file
@@ -0,0 +1 @@
|
||||
Update this directory using maint_tools/vendor_array_api_extra.sh
|
||||
38
venv/lib/python3.12/site-packages/sklearn/externals/array_api_extra/__init__.py
vendored
Normal file
38
venv/lib/python3.12/site-packages/sklearn/externals/array_api_extra/__init__.py
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Extra array functions built on top of the array API standard."""
|
||||
|
||||
from ._delegation import isclose, pad
|
||||
from ._lib._at import at
|
||||
from ._lib._funcs import (
|
||||
apply_where,
|
||||
atleast_nd,
|
||||
broadcast_shapes,
|
||||
cov,
|
||||
create_diagonal,
|
||||
expand_dims,
|
||||
kron,
|
||||
nunique,
|
||||
setdiff1d,
|
||||
sinc,
|
||||
)
|
||||
from ._lib._lazy import lazy_apply
|
||||
|
||||
__version__ = "0.7.1"
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"apply_where",
|
||||
"at",
|
||||
"atleast_nd",
|
||||
"broadcast_shapes",
|
||||
"cov",
|
||||
"create_diagonal",
|
||||
"expand_dims",
|
||||
"isclose",
|
||||
"kron",
|
||||
"lazy_apply",
|
||||
"nunique",
|
||||
"pad",
|
||||
"setdiff1d",
|
||||
"sinc",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
172
venv/lib/python3.12/site-packages/sklearn/externals/array_api_extra/_delegation.py
vendored
Normal file
172
venv/lib/python3.12/site-packages/sklearn/externals/array_api_extra/_delegation.py
vendored
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Delegation to existing implementations for Public API Functions."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from types import ModuleType
|
||||
from typing import Literal
|
||||
|
||||
from ._lib import Backend, _funcs
|
||||
from ._lib._utils._compat import array_namespace
|
||||
from ._lib._utils._helpers import asarrays
|
||||
from ._lib._utils._typing import Array
|
||||
|
||||
__all__ = ["isclose", "pad"]
|
||||
|
||||
|
||||
def _delegate(xp: ModuleType, *backends: Backend) -> bool:
|
||||
"""
|
||||
Check whether `xp` is one of the `backends` to delegate to.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
xp : array_namespace
|
||||
Array namespace to check.
|
||||
*backends : IsNamespace
|
||||
Arbitrarily many backends (from the ``IsNamespace`` enum) to check.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
``True`` if `xp` matches one of the `backends`, ``False`` otherwise.
|
||||
"""
|
||||
return any(backend.is_namespace(xp) for backend in backends)
|
||||
|
||||
|
||||
def isclose(
|
||||
a: Array | complex,
|
||||
b: Array | complex,
|
||||
*,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
equal_nan: bool = False,
|
||||
xp: ModuleType | None = None,
|
||||
) -> Array:
|
||||
"""
|
||||
Return a boolean array where two arrays are element-wise equal within a tolerance.
|
||||
|
||||
The tolerance values are positive, typically very small numbers. The relative
|
||||
difference ``(rtol * abs(b))`` and the absolute difference `atol` are added together
|
||||
to compare against the absolute difference between `a` and `b`.
|
||||
|
||||
NaNs are treated as equal if they are in the same place and if ``equal_nan=True``.
|
||||
Infs are treated as equal if they are in the same place and of the same sign in both
|
||||
arrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a, b : Array | int | float | complex | bool
|
||||
Input objects to compare. At least one must be an array.
|
||||
rtol : array_like, optional
|
||||
The relative tolerance parameter (see Notes).
|
||||
atol : array_like, optional
|
||||
The absolute tolerance parameter (see Notes).
|
||||
equal_nan : bool, optional
|
||||
Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
|
||||
equal to NaN's in `b` in the output array.
|
||||
xp : array_namespace, optional
|
||||
The standard-compatible namespace for `a` and `b`. Default: infer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Array
|
||||
A boolean array of shape broadcasted from `a` and `b`, containing ``True`` where
|
||||
`a` is close to `b`, and ``False`` otherwise.
|
||||
|
||||
Warnings
|
||||
--------
|
||||
The default `atol` is not appropriate for comparing numbers with magnitudes much
|
||||
smaller than one (see notes).
|
||||
|
||||
See Also
|
||||
--------
|
||||
math.isclose : Similar function in stdlib for Python scalars.
|
||||
|
||||
Notes
|
||||
-----
|
||||
For finite values, `isclose` uses the following equation to test whether two
|
||||
floating point values are equivalent::
|
||||
|
||||
absolute(a - b) <= (atol + rtol * absolute(b))
|
||||
|
||||
Unlike the built-in `math.isclose`,
|
||||
the above equation is not symmetric in `a` and `b`,
|
||||
so that ``isclose(a, b)`` might be different from ``isclose(b, a)`` in some rare
|
||||
cases.
|
||||
|
||||
The default value of `atol` is not appropriate when the reference value `b` has
|
||||
magnitude smaller than one. For example, it is unlikely that ``a = 1e-9`` and
|
||||
``b = 2e-9`` should be considered "close", yet ``isclose(1e-9, 2e-9)`` is ``True``
|
||||
with default settings. Be sure to select `atol` for the use case at hand, especially
|
||||
for defining the threshold below which a non-zero value in `a` will be considered
|
||||
"close" to a very small or zero value in `b`.
|
||||
|
||||
The comparison of `a` and `b` uses standard broadcasting, which means that `a` and
|
||||
`b` need not have the same shape in order for ``isclose(a, b)`` to evaluate to
|
||||
``True``.
|
||||
|
||||
`isclose` is not defined for non-numeric data types.
|
||||
``bool`` is considered a numeric data-type for this purpose.
|
||||
"""
|
||||
xp = array_namespace(a, b) if xp is None else xp
|
||||
|
||||
if _delegate(xp, Backend.NUMPY, Backend.CUPY, Backend.DASK, Backend.JAX):
|
||||
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
||||
|
||||
if _delegate(xp, Backend.TORCH):
|
||||
a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support
|
||||
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
||||
|
||||
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
|
||||
|
||||
|
||||
def pad(
|
||||
x: Array,
|
||||
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
|
||||
mode: Literal["constant"] = "constant",
|
||||
*,
|
||||
constant_values: complex = 0,
|
||||
xp: ModuleType | None = None,
|
||||
) -> Array:
|
||||
"""
|
||||
Pad the input array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : array
|
||||
Input array.
|
||||
pad_width : int or tuple of ints or sequence of pairs of ints
|
||||
Pad the input array with this many elements from each side.
|
||||
If a sequence of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
|
||||
each pair applies to the corresponding axis of ``x``.
|
||||
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
|
||||
copies of this tuple.
|
||||
mode : str, optional
|
||||
Only "constant" mode is currently supported, which pads with
|
||||
the value passed to `constant_values`.
|
||||
constant_values : python scalar, optional
|
||||
Use this value to pad the input. Default is zero.
|
||||
xp : array_namespace, optional
|
||||
The standard-compatible namespace for `x`. Default: infer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
array
|
||||
The input array,
|
||||
padded with ``pad_width`` elements equal to ``constant_values``.
|
||||
"""
|
||||
xp = array_namespace(x) if xp is None else xp
|
||||
|
||||
if mode != "constant":
|
||||
msg = "Only `'constant'` mode is currently supported"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
|
||||
if _delegate(xp, Backend.TORCH):
|
||||
pad_width = xp.asarray(pad_width)
|
||||
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
|
||||
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
|
||||
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
|
||||
if _delegate(xp, Backend.NUMPY, Backend.JAX, Backend.CUPY, Backend.SPARSE):
|
||||
return xp.pad(x, pad_width, mode, constant_values=constant_values)
|
||||
|
||||
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
|
||||
5
venv/lib/python3.12/site-packages/sklearn/externals/array_api_extra/_lib/__init__.py
vendored
Normal file
5
venv/lib/python3.12/site-packages/sklearn/externals/array_api_extra/_lib/__init__.py
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Internals of array-api-extra."""
|
||||
|
||||
from ._backends import Backend
|
||||
|
||||
__all__ = ["Backend"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user