add read me
This commit is contained in:
84
venv/lib/python3.12/site-packages/sklearn/utils/__init__.py
Normal file
84
venv/lib/python3.12/site-packages/sklearn/utils/__init__.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Various utilities to help with development."""
|
||||
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from ..exceptions import DataConversionWarning
|
||||
from . import metadata_routing
|
||||
from ._bunch import Bunch
|
||||
from ._chunking import gen_batches, gen_even_slices
|
||||
|
||||
# Make _safe_indexing importable from here for backward compat as this particular
|
||||
# helper is considered semi-private and typically very useful for third-party
|
||||
# libraries that want to comply with scikit-learn's estimator API. In particular,
|
||||
# _safe_indexing was included in our public API documentation despite the leading
|
||||
# `_` in its name.
|
||||
from ._indexing import (
|
||||
_safe_indexing, # noqa: F401
|
||||
resample,
|
||||
shuffle,
|
||||
)
|
||||
from ._mask import safe_mask
|
||||
from ._repr_html.base import _HTMLDocumentationLinkMixin # noqa: F401
|
||||
from ._repr_html.estimator import estimator_html_repr
|
||||
from ._tags import (
|
||||
ClassifierTags,
|
||||
InputTags,
|
||||
RegressorTags,
|
||||
Tags,
|
||||
TargetTags,
|
||||
TransformerTags,
|
||||
get_tags,
|
||||
)
|
||||
from .class_weight import compute_class_weight, compute_sample_weight
|
||||
from .deprecation import deprecated
|
||||
from .discovery import all_estimators
|
||||
from .extmath import safe_sqr
|
||||
from .murmurhash import murmurhash3_32
|
||||
from .validation import (
|
||||
as_float_array,
|
||||
assert_all_finite,
|
||||
check_array,
|
||||
check_consistent_length,
|
||||
check_random_state,
|
||||
check_scalar,
|
||||
check_symmetric,
|
||||
check_X_y,
|
||||
column_or_1d,
|
||||
indexable,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Bunch",
|
||||
"ClassifierTags",
|
||||
"DataConversionWarning",
|
||||
"InputTags",
|
||||
"RegressorTags",
|
||||
"Tags",
|
||||
"TargetTags",
|
||||
"TransformerTags",
|
||||
"all_estimators",
|
||||
"as_float_array",
|
||||
"assert_all_finite",
|
||||
"check_X_y",
|
||||
"check_array",
|
||||
"check_consistent_length",
|
||||
"check_random_state",
|
||||
"check_scalar",
|
||||
"check_symmetric",
|
||||
"column_or_1d",
|
||||
"compute_class_weight",
|
||||
"compute_sample_weight",
|
||||
"deprecated",
|
||||
"estimator_html_repr",
|
||||
"gen_batches",
|
||||
"gen_even_slices",
|
||||
"get_tags",
|
||||
"indexable",
|
||||
"metadata_routing",
|
||||
"murmurhash3_32",
|
||||
"resample",
|
||||
"safe_mask",
|
||||
"safe_sqr",
|
||||
"shuffle",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
33
venv/lib/python3.12/site-packages/sklearn/utils/_arpack.py
Normal file
33
venv/lib/python3.12/site-packages/sklearn/utils/_arpack.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from .validation import check_random_state
|
||||
|
||||
|
||||
def _init_arpack_v0(size, random_state):
|
||||
"""Initialize the starting vector for iteration in ARPACK functions.
|
||||
|
||||
Initialize a ndarray with values sampled from the uniform distribution on
|
||||
[-1, 1]. This initialization model has been chosen to be consistent with
|
||||
the ARPACK one as another initialization can lead to convergence issues.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
size : int
|
||||
The size of the eigenvalue vector to be initialized.
|
||||
|
||||
random_state : int, RandomState instance or None, default=None
|
||||
The seed of the pseudo random number generator used to generate a
|
||||
uniform distribution. If int, random_state is the seed used by the
|
||||
random number generator; If RandomState instance, random_state is the
|
||||
random number generator; If None, the random number generator is the
|
||||
RandomState instance used by `np.random`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
v0 : ndarray of shape (size,)
|
||||
The initialized vector.
|
||||
"""
|
||||
random_state = check_random_state(random_state)
|
||||
v0 = random_state.uniform(-1, 1, size)
|
||||
return v0
|
||||
1006
venv/lib/python3.12/site-packages/sklearn/utils/_array_api.py
Normal file
1006
venv/lib/python3.12/site-packages/sklearn/utils/_array_api.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,96 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from functools import update_wrapper, wraps
|
||||
from types import MethodType
|
||||
|
||||
|
||||
class _AvailableIfDescriptor:
|
||||
"""Implements a conditional property using the descriptor protocol.
|
||||
|
||||
Using this class to create a decorator will raise an ``AttributeError``
|
||||
if check(self) returns a falsey value. Note that if check raises an error
|
||||
this will also result in hasattr returning false.
|
||||
|
||||
See https://docs.python.org/3/howto/descriptor.html for an explanation of
|
||||
descriptors.
|
||||
"""
|
||||
|
||||
def __init__(self, fn, check, attribute_name):
|
||||
self.fn = fn
|
||||
self.check = check
|
||||
self.attribute_name = attribute_name
|
||||
|
||||
# update the docstring of the descriptor
|
||||
update_wrapper(self, fn)
|
||||
|
||||
def _check(self, obj, owner):
|
||||
attr_err_msg = (
|
||||
f"This {owner.__name__!r} has no attribute {self.attribute_name!r}"
|
||||
)
|
||||
try:
|
||||
check_result = self.check(obj)
|
||||
except Exception as e:
|
||||
raise AttributeError(attr_err_msg) from e
|
||||
|
||||
if not check_result:
|
||||
raise AttributeError(attr_err_msg)
|
||||
|
||||
def __get__(self, obj, owner=None):
|
||||
if obj is not None:
|
||||
# delegate only on instances, not the classes.
|
||||
# this is to allow access to the docstrings.
|
||||
self._check(obj, owner=owner)
|
||||
out = MethodType(self.fn, obj)
|
||||
|
||||
else:
|
||||
# This makes it possible to use the decorated method as an unbound method,
|
||||
# for instance when monkeypatching.
|
||||
@wraps(self.fn)
|
||||
def out(*args, **kwargs):
|
||||
self._check(args[0], owner=owner)
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def available_if(check):
|
||||
"""An attribute that is available only if check returns a truthy value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
check : callable
|
||||
When passed the object with the decorated method, this should return
|
||||
a truthy value if the attribute is available, and either return False
|
||||
or raise an AttributeError if not available.
|
||||
|
||||
Returns
|
||||
-------
|
||||
callable
|
||||
Callable makes the decorated method available if `check` returns
|
||||
a truthy value, otherwise the decorated method is unavailable.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from sklearn.utils.metaestimators import available_if
|
||||
>>> class HelloIfEven:
|
||||
... def __init__(self, x):
|
||||
... self.x = x
|
||||
...
|
||||
... def _x_is_even(self):
|
||||
... return self.x % 2 == 0
|
||||
...
|
||||
... @available_if(_x_is_even)
|
||||
... def say_hello(self):
|
||||
... print("Hello")
|
||||
...
|
||||
>>> obj = HelloIfEven(1)
|
||||
>>> hasattr(obj, "say_hello")
|
||||
False
|
||||
>>> obj.x = 2
|
||||
>>> hasattr(obj, "say_hello")
|
||||
True
|
||||
>>> obj.say_hello()
|
||||
Hello
|
||||
"""
|
||||
return lambda fn: _AvailableIfDescriptor(fn, check, attribute_name=fn.__name__)
|
||||
70
venv/lib/python3.12/site-packages/sklearn/utils/_bunch.py
Normal file
70
venv/lib/python3.12/site-packages/sklearn/utils/_bunch.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
class Bunch(dict):
|
||||
"""Container object exposing keys as attributes.
|
||||
|
||||
Bunch objects are sometimes used as an output for functions and methods.
|
||||
They extend dictionaries by enabling values to be accessed by key,
|
||||
`bunch["value_key"]`, or by an attribute, `bunch.value_key`.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from sklearn.utils import Bunch
|
||||
>>> b = Bunch(a=1, b=2)
|
||||
>>> b['b']
|
||||
2
|
||||
>>> b.b
|
||||
2
|
||||
>>> b.a = 3
|
||||
>>> b['a']
|
||||
3
|
||||
>>> b.c = 6
|
||||
>>> b['c']
|
||||
6
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(kwargs)
|
||||
|
||||
# Map from deprecated key to warning message
|
||||
self.__dict__["_deprecated_key_to_warnings"] = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self.__dict__.get("_deprecated_key_to_warnings", {}):
|
||||
warnings.warn(
|
||||
self._deprecated_key_to_warnings[key],
|
||||
FutureWarning,
|
||||
)
|
||||
return super().__getitem__(key)
|
||||
|
||||
def _set_deprecated(self, value, *, new_key, deprecated_key, warning_message):
|
||||
"""Set key in dictionary to be deprecated with its warning message."""
|
||||
self.__dict__["_deprecated_key_to_warnings"][deprecated_key] = warning_message
|
||||
self[new_key] = self[deprecated_key] = value
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
self[key] = value
|
||||
|
||||
def __dir__(self):
|
||||
return self.keys()
|
||||
|
||||
def __getattr__(self, key):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(key)
|
||||
|
||||
def __setstate__(self, state):
|
||||
# Bunch pickles generated with scikit-learn 0.16.* have an non
|
||||
# empty __dict__. This causes a surprising behaviour when
|
||||
# loading these pickles scikit-learn 0.17: reading bunch.key
|
||||
# uses __dict__ but assigning to bunch.key use __setattr__ and
|
||||
# only changes bunch['key']. More details can be found at:
|
||||
# https://github.com/scikit-learn/scikit-learn/issues/6196.
|
||||
# Overriding __setstate__ to be a noop has the effect of
|
||||
# ignoring the pickled __dict__
|
||||
pass
|
||||
178
venv/lib/python3.12/site-packages/sklearn/utils/_chunking.py
Normal file
178
venv/lib/python3.12/site-packages/sklearn/utils/_chunking.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import warnings
|
||||
from itertools import islice
|
||||
from numbers import Integral
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._config import get_config
|
||||
from ._param_validation import Interval, validate_params
|
||||
|
||||
|
||||
def chunk_generator(gen, chunksize):
|
||||
"""Chunk generator, ``gen`` into lists of length ``chunksize``. The last
|
||||
chunk may have a length less than ``chunksize``."""
|
||||
while True:
|
||||
chunk = list(islice(gen, chunksize))
|
||||
if chunk:
|
||||
yield chunk
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
@validate_params(
|
||||
{
|
||||
"n": [Interval(Integral, 1, None, closed="left")],
|
||||
"batch_size": [Interval(Integral, 1, None, closed="left")],
|
||||
"min_batch_size": [Interval(Integral, 0, None, closed="left")],
|
||||
},
|
||||
prefer_skip_nested_validation=True,
|
||||
)
|
||||
def gen_batches(n, batch_size, *, min_batch_size=0):
|
||||
"""Generator to create slices containing `batch_size` elements from 0 to `n`.
|
||||
|
||||
The last slice may contain less than `batch_size` elements, when
|
||||
`batch_size` does not divide `n`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n : int
|
||||
Size of the sequence.
|
||||
batch_size : int
|
||||
Number of elements in each batch.
|
||||
min_batch_size : int, default=0
|
||||
Minimum number of elements in each batch.
|
||||
|
||||
Yields
|
||||
------
|
||||
slice of `batch_size` elements
|
||||
|
||||
See Also
|
||||
--------
|
||||
gen_even_slices: Generator to create n_packs slices going up to n.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from sklearn.utils import gen_batches
|
||||
>>> list(gen_batches(7, 3))
|
||||
[slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
|
||||
>>> list(gen_batches(6, 3))
|
||||
[slice(0, 3, None), slice(3, 6, None)]
|
||||
>>> list(gen_batches(2, 3))
|
||||
[slice(0, 2, None)]
|
||||
>>> list(gen_batches(7, 3, min_batch_size=0))
|
||||
[slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
|
||||
>>> list(gen_batches(7, 3, min_batch_size=2))
|
||||
[slice(0, 3, None), slice(3, 7, None)]
|
||||
"""
|
||||
start = 0
|
||||
for _ in range(int(n // batch_size)):
|
||||
end = start + batch_size
|
||||
if end + min_batch_size > n:
|
||||
continue
|
||||
yield slice(start, end)
|
||||
start = end
|
||||
if start < n:
|
||||
yield slice(start, n)
|
||||
|
||||
|
||||
@validate_params(
|
||||
{
|
||||
"n": [Interval(Integral, 1, None, closed="left")],
|
||||
"n_packs": [Interval(Integral, 1, None, closed="left")],
|
||||
"n_samples": [Interval(Integral, 1, None, closed="left"), None],
|
||||
},
|
||||
prefer_skip_nested_validation=True,
|
||||
)
|
||||
def gen_even_slices(n, n_packs, *, n_samples=None):
|
||||
"""Generator to create `n_packs` evenly spaced slices going up to `n`.
|
||||
|
||||
If `n_packs` does not divide `n`, except for the first `n % n_packs`
|
||||
slices, remaining slices may contain fewer elements.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n : int
|
||||
Size of the sequence.
|
||||
n_packs : int
|
||||
Number of slices to generate.
|
||||
n_samples : int, default=None
|
||||
Number of samples. Pass `n_samples` when the slices are to be used for
|
||||
sparse matrix indexing; slicing off-the-end raises an exception, while
|
||||
it works for NumPy arrays.
|
||||
|
||||
Yields
|
||||
------
|
||||
`slice` representing a set of indices from 0 to n.
|
||||
|
||||
See Also
|
||||
--------
|
||||
gen_batches: Generator to create slices containing batch_size elements
|
||||
from 0 to n.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from sklearn.utils import gen_even_slices
|
||||
>>> list(gen_even_slices(10, 1))
|
||||
[slice(0, 10, None)]
|
||||
>>> list(gen_even_slices(10, 10))
|
||||
[slice(0, 1, None), slice(1, 2, None), ..., slice(9, 10, None)]
|
||||
>>> list(gen_even_slices(10, 5))
|
||||
[slice(0, 2, None), slice(2, 4, None), ..., slice(8, 10, None)]
|
||||
>>> list(gen_even_slices(10, 3))
|
||||
[slice(0, 4, None), slice(4, 7, None), slice(7, 10, None)]
|
||||
"""
|
||||
start = 0
|
||||
for pack_num in range(n_packs):
|
||||
this_n = n // n_packs
|
||||
if pack_num < n % n_packs:
|
||||
this_n += 1
|
||||
if this_n > 0:
|
||||
end = start + this_n
|
||||
if n_samples is not None:
|
||||
end = min(n_samples, end)
|
||||
yield slice(start, end, None)
|
||||
start = end
|
||||
|
||||
|
||||
def get_chunk_n_rows(row_bytes, *, max_n_rows=None, working_memory=None):
|
||||
"""Calculate how many rows can be processed within `working_memory`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
row_bytes : int
|
||||
The expected number of bytes of memory that will be consumed
|
||||
during the processing of each row.
|
||||
max_n_rows : int, default=None
|
||||
The maximum return value.
|
||||
working_memory : int or float, default=None
|
||||
The number of rows to fit inside this number of MiB will be
|
||||
returned. When None (default), the value of
|
||||
``sklearn.get_config()['working_memory']`` is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The number of rows which can be processed within `working_memory`.
|
||||
|
||||
Warns
|
||||
-----
|
||||
Issues a UserWarning if `row_bytes exceeds `working_memory` MiB.
|
||||
"""
|
||||
|
||||
if working_memory is None:
|
||||
working_memory = get_config()["working_memory"]
|
||||
|
||||
chunk_n_rows = int(working_memory * (2**20) // row_bytes)
|
||||
if max_n_rows is not None:
|
||||
chunk_n_rows = min(chunk_n_rows, max_n_rows)
|
||||
if chunk_n_rows < 1:
|
||||
warnings.warn(
|
||||
"Could not adhere to working_memory config. "
|
||||
"Currently %.0fMiB, %.0fMiB required."
|
||||
% (working_memory, np.ceil(row_bytes * 2**-20))
|
||||
)
|
||||
chunk_n_rows = 1
|
||||
return chunk_n_rows
|
||||
Binary file not shown.
@@ -0,0 +1,41 @@
|
||||
from cython cimport floating
|
||||
|
||||
|
||||
cpdef enum BLAS_Order:
|
||||
RowMajor # C contiguous
|
||||
ColMajor # Fortran contiguous
|
||||
|
||||
|
||||
cpdef enum BLAS_Trans:
|
||||
NoTrans = 110 # correspond to 'n'
|
||||
Trans = 116 # correspond to 't'
|
||||
|
||||
|
||||
# BLAS Level 1 ################################################################
|
||||
cdef floating _dot(int, const floating*, int, const floating*, int) noexcept nogil
|
||||
|
||||
cdef floating _asum(int, const floating*, int) noexcept nogil
|
||||
|
||||
cdef void _axpy(int, floating, const floating*, int, floating*, int) noexcept nogil
|
||||
|
||||
cdef floating _nrm2(int, const floating*, int) noexcept nogil
|
||||
|
||||
cdef void _copy(int, const floating*, int, const floating*, int) noexcept nogil
|
||||
|
||||
cdef void _scal(int, floating, const floating*, int) noexcept nogil
|
||||
|
||||
cdef void _rotg(floating*, floating*, floating*, floating*) noexcept nogil
|
||||
|
||||
cdef void _rot(int, floating*, int, floating*, int, floating, floating) noexcept nogil
|
||||
|
||||
# BLAS Level 2 ################################################################
|
||||
cdef void _gemv(BLAS_Order, BLAS_Trans, int, int, floating, const floating*, int,
|
||||
const floating*, int, floating, floating*, int) noexcept nogil
|
||||
|
||||
cdef void _ger(BLAS_Order, int, int, floating, const floating*, int, const floating*,
|
||||
int, floating*, int) noexcept nogil
|
||||
|
||||
# BLASLevel 3 ################################################################
|
||||
cdef void _gemm(BLAS_Order, BLAS_Trans, BLAS_Trans, int, int, int, floating,
|
||||
const floating*, int, const floating*, int, floating, floating*,
|
||||
int) noexcept nogil
|
||||
239
venv/lib/python3.12/site-packages/sklearn/utils/_cython_blas.pyx
Normal file
239
venv/lib/python3.12/site-packages/sklearn/utils/_cython_blas.pyx
Normal file
@@ -0,0 +1,239 @@
|
||||
from cython cimport floating
|
||||
|
||||
from scipy.linalg.cython_blas cimport sdot, ddot
|
||||
from scipy.linalg.cython_blas cimport sasum, dasum
|
||||
from scipy.linalg.cython_blas cimport saxpy, daxpy
|
||||
from scipy.linalg.cython_blas cimport snrm2, dnrm2
|
||||
from scipy.linalg.cython_blas cimport scopy, dcopy
|
||||
from scipy.linalg.cython_blas cimport sscal, dscal
|
||||
from scipy.linalg.cython_blas cimport srotg, drotg
|
||||
from scipy.linalg.cython_blas cimport srot, drot
|
||||
from scipy.linalg.cython_blas cimport sgemv, dgemv
|
||||
from scipy.linalg.cython_blas cimport sger, dger
|
||||
from scipy.linalg.cython_blas cimport sgemm, dgemm
|
||||
|
||||
|
||||
################
|
||||
# BLAS Level 1 #
|
||||
################
|
||||
|
||||
cdef floating _dot(int n, const floating *x, int incx,
|
||||
const floating *y, int incy) noexcept nogil:
|
||||
"""x.T.y"""
|
||||
if floating is float:
|
||||
return sdot(&n, <float *> x, &incx, <float *> y, &incy)
|
||||
else:
|
||||
return ddot(&n, <double *> x, &incx, <double *> y, &incy)
|
||||
|
||||
|
||||
cpdef _dot_memview(const floating[::1] x, const floating[::1] y):
|
||||
return _dot(x.shape[0], &x[0], 1, &y[0], 1)
|
||||
|
||||
|
||||
cdef floating _asum(int n, const floating *x, int incx) noexcept nogil:
|
||||
"""sum(|x_i|)"""
|
||||
if floating is float:
|
||||
return sasum(&n, <float *> x, &incx)
|
||||
else:
|
||||
return dasum(&n, <double *> x, &incx)
|
||||
|
||||
|
||||
cpdef _asum_memview(const floating[::1] x):
|
||||
return _asum(x.shape[0], &x[0], 1)
|
||||
|
||||
|
||||
cdef void _axpy(int n, floating alpha, const floating *x, int incx,
|
||||
floating *y, int incy) noexcept nogil:
|
||||
"""y := alpha * x + y"""
|
||||
if floating is float:
|
||||
saxpy(&n, &alpha, <float *> x, &incx, y, &incy)
|
||||
else:
|
||||
daxpy(&n, &alpha, <double *> x, &incx, y, &incy)
|
||||
|
||||
|
||||
cpdef _axpy_memview(floating alpha, const floating[::1] x, floating[::1] y):
|
||||
_axpy(x.shape[0], alpha, &x[0], 1, &y[0], 1)
|
||||
|
||||
|
||||
cdef floating _nrm2(int n, const floating *x, int incx) noexcept nogil:
|
||||
"""sqrt(sum((x_i)^2))"""
|
||||
if floating is float:
|
||||
return snrm2(&n, <float *> x, &incx)
|
||||
else:
|
||||
return dnrm2(&n, <double *> x, &incx)
|
||||
|
||||
|
||||
cpdef _nrm2_memview(const floating[::1] x):
|
||||
return _nrm2(x.shape[0], &x[0], 1)
|
||||
|
||||
|
||||
cdef void _copy(int n, const floating *x, int incx, const floating *y, int incy) noexcept nogil:
|
||||
"""y := x"""
|
||||
if floating is float:
|
||||
scopy(&n, <float *> x, &incx, <float *> y, &incy)
|
||||
else:
|
||||
dcopy(&n, <double *> x, &incx, <double *> y, &incy)
|
||||
|
||||
|
||||
cpdef _copy_memview(const floating[::1] x, const floating[::1] y):
|
||||
_copy(x.shape[0], &x[0], 1, &y[0], 1)
|
||||
|
||||
|
||||
cdef void _scal(int n, floating alpha, const floating *x, int incx) noexcept nogil:
|
||||
"""x := alpha * x"""
|
||||
if floating is float:
|
||||
sscal(&n, &alpha, <float *> x, &incx)
|
||||
else:
|
||||
dscal(&n, &alpha, <double *> x, &incx)
|
||||
|
||||
|
||||
cpdef _scal_memview(floating alpha, const floating[::1] x):
|
||||
_scal(x.shape[0], alpha, &x[0], 1)
|
||||
|
||||
|
||||
cdef void _rotg(floating *a, floating *b, floating *c, floating *s) noexcept nogil:
|
||||
"""Generate plane rotation"""
|
||||
if floating is float:
|
||||
srotg(a, b, c, s)
|
||||
else:
|
||||
drotg(a, b, c, s)
|
||||
|
||||
|
||||
cpdef _rotg_memview(floating a, floating b, floating c, floating s):
|
||||
_rotg(&a, &b, &c, &s)
|
||||
return a, b, c, s
|
||||
|
||||
|
||||
cdef void _rot(int n, floating *x, int incx, floating *y, int incy,
|
||||
floating c, floating s) noexcept nogil:
|
||||
"""Apply plane rotation"""
|
||||
if floating is float:
|
||||
srot(&n, x, &incx, y, &incy, &c, &s)
|
||||
else:
|
||||
drot(&n, x, &incx, y, &incy, &c, &s)
|
||||
|
||||
|
||||
cpdef _rot_memview(floating[::1] x, floating[::1] y, floating c, floating s):
|
||||
_rot(x.shape[0], &x[0], 1, &y[0], 1, c, s)
|
||||
|
||||
|
||||
################
|
||||
# BLAS Level 2 #
|
||||
################
|
||||
|
||||
cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha,
|
||||
const floating *A, int lda, const floating *x, int incx,
|
||||
floating beta, floating *y, int incy) noexcept nogil:
|
||||
"""y := alpha * op(A).x + beta * y"""
|
||||
cdef char ta_ = ta
|
||||
if order == BLAS_Order.RowMajor:
|
||||
ta_ = BLAS_Trans.NoTrans if ta == BLAS_Trans.Trans else BLAS_Trans.Trans
|
||||
if floating is float:
|
||||
sgemv(&ta_, &n, &m, &alpha, <float *> A, &lda, <float *> x,
|
||||
&incx, &beta, y, &incy)
|
||||
else:
|
||||
dgemv(&ta_, &n, &m, &alpha, <double *> A, &lda, <double *> x,
|
||||
&incx, &beta, y, &incy)
|
||||
else:
|
||||
if floating is float:
|
||||
sgemv(&ta_, &m, &n, &alpha, <float *> A, &lda, <float *> x,
|
||||
&incx, &beta, y, &incy)
|
||||
else:
|
||||
dgemv(&ta_, &m, &n, &alpha, <double *> A, &lda, <double *> x,
|
||||
&incx, &beta, y, &incy)
|
||||
|
||||
|
||||
cpdef _gemv_memview(BLAS_Trans ta, floating alpha, const floating[:, :] A,
|
||||
const floating[::1] x, floating beta, floating[::1] y):
|
||||
cdef:
|
||||
int m = A.shape[0]
|
||||
int n = A.shape[1]
|
||||
BLAS_Order order = (
|
||||
BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
|
||||
)
|
||||
int lda = m if order == BLAS_Order.ColMajor else n
|
||||
|
||||
_gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1)
|
||||
|
||||
|
||||
cdef void _ger(BLAS_Order order, int m, int n, floating alpha,
|
||||
const floating *x, int incx, const floating *y,
|
||||
int incy, floating *A, int lda) noexcept nogil:
|
||||
"""A := alpha * x.y.T + A"""
|
||||
if order == BLAS_Order.RowMajor:
|
||||
if floating is float:
|
||||
sger(&n, &m, &alpha, <float *> y, &incy, <float *> x, &incx, A, &lda)
|
||||
else:
|
||||
dger(&n, &m, &alpha, <double *> y, &incy, <double *> x, &incx, A, &lda)
|
||||
else:
|
||||
if floating is float:
|
||||
sger(&m, &n, &alpha, <float *> x, &incx, <float *> y, &incy, A, &lda)
|
||||
else:
|
||||
dger(&m, &n, &alpha, <double *> x, &incx, <double *> y, &incy, A, &lda)
|
||||
|
||||
|
||||
cpdef _ger_memview(floating alpha, const floating[::1] x,
|
||||
const floating[::1] y, floating[:, :] A):
|
||||
cdef:
|
||||
int m = A.shape[0]
|
||||
int n = A.shape[1]
|
||||
BLAS_Order order = (
|
||||
BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
|
||||
)
|
||||
int lda = m if order == BLAS_Order.ColMajor else n
|
||||
|
||||
_ger(order, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
|
||||
|
||||
|
||||
################
|
||||
# BLAS Level 3 #
|
||||
################
|
||||
|
||||
cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
|
||||
int k, floating alpha, const floating *A, int lda, const floating *B,
|
||||
int ldb, floating beta, floating *C, int ldc) noexcept nogil:
|
||||
"""C := alpha * op(A).op(B) + beta * C"""
|
||||
# TODO: Remove the pointer casts below once SciPy uses const-qualification.
|
||||
# See: https://github.com/scipy/scipy/issues/14262
|
||||
cdef:
|
||||
char ta_ = ta
|
||||
char tb_ = tb
|
||||
if order == BLAS_Order.RowMajor:
|
||||
if floating is float:
|
||||
sgemm(&tb_, &ta_, &n, &m, &k, &alpha, <float*>B,
|
||||
&ldb, <float*>A, &lda, &beta, C, &ldc)
|
||||
else:
|
||||
dgemm(&tb_, &ta_, &n, &m, &k, &alpha, <double*>B,
|
||||
&ldb, <double*>A, &lda, &beta, C, &ldc)
|
||||
else:
|
||||
if floating is float:
|
||||
sgemm(&ta_, &tb_, &m, &n, &k, &alpha, <float*>A,
|
||||
&lda, <float*>B, &ldb, &beta, C, &ldc)
|
||||
else:
|
||||
dgemm(&ta_, &tb_, &m, &n, &k, &alpha, <double*>A,
|
||||
&lda, <double*>B, &ldb, &beta, C, &ldc)
|
||||
|
||||
|
||||
cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha,
|
||||
const floating[:, :] A, const floating[:, :] B, floating beta,
|
||||
floating[:, :] C):
|
||||
cdef:
|
||||
int m = A.shape[0] if ta == BLAS_Trans.NoTrans else A.shape[1]
|
||||
int n = B.shape[1] if tb == BLAS_Trans.NoTrans else B.shape[0]
|
||||
int k = A.shape[1] if ta == BLAS_Trans.NoTrans else A.shape[0]
|
||||
int lda, ldb, ldc
|
||||
BLAS_Order order = (
|
||||
BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
|
||||
)
|
||||
|
||||
if order == BLAS_Order.RowMajor:
|
||||
lda = k if ta == BLAS_Trans.NoTrans else m
|
||||
ldb = n if tb == BLAS_Trans.NoTrans else k
|
||||
ldc = n
|
||||
else:
|
||||
lda = m if ta == BLAS_Trans.NoTrans else k
|
||||
ldb = k if tb == BLAS_Trans.NoTrans else n
|
||||
ldc = m
|
||||
|
||||
_gemm(order, ta, tb, m, n, k, alpha, &A[0, 0],
|
||||
lda, &B[0, 0], ldb, beta, &C[0, 0], ldc)
|
||||
376
venv/lib/python3.12/site-packages/sklearn/utils/_encode.py
Normal file
376
venv/lib/python3.12/site-packages/sklearn/utils/_encode.py
Normal file
@@ -0,0 +1,376 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from collections import Counter
|
||||
from contextlib import suppress
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._array_api import (
|
||||
_isin,
|
||||
_searchsorted,
|
||||
device,
|
||||
get_namespace,
|
||||
xpx,
|
||||
)
|
||||
from ._missing import is_scalar_nan
|
||||
|
||||
|
||||
def _unique(values, *, return_inverse=False, return_counts=False):
|
||||
"""Helper function to find unique values with support for python objects.
|
||||
|
||||
Uses pure python method for object dtype, and numpy method for
|
||||
all other dtypes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
values : ndarray
|
||||
Values to check for unknowns.
|
||||
|
||||
return_inverse : bool, default=False
|
||||
If True, also return the indices of the unique values.
|
||||
|
||||
return_counts : bool, default=False
|
||||
If True, also return the number of times each unique item appears in
|
||||
values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
unique : ndarray
|
||||
The sorted unique values.
|
||||
|
||||
unique_inverse : ndarray
|
||||
The indices to reconstruct the original array from the unique array.
|
||||
Only provided if `return_inverse` is True.
|
||||
|
||||
unique_counts : ndarray
|
||||
The number of times each of the unique values comes up in the original
|
||||
array. Only provided if `return_counts` is True.
|
||||
"""
|
||||
if values.dtype == object:
|
||||
return _unique_python(
|
||||
values, return_inverse=return_inverse, return_counts=return_counts
|
||||
)
|
||||
# numerical
|
||||
return _unique_np(
|
||||
values, return_inverse=return_inverse, return_counts=return_counts
|
||||
)
|
||||
|
||||
|
||||
def _unique_np(values, return_inverse=False, return_counts=False):
|
||||
"""Helper function to find unique values for numpy arrays that correctly
|
||||
accounts for nans. See `_unique` documentation for details."""
|
||||
xp, _ = get_namespace(values)
|
||||
|
||||
inverse, counts = None, None
|
||||
|
||||
if return_inverse and return_counts:
|
||||
uniques, _, inverse, counts = xp.unique_all(values)
|
||||
elif return_inverse:
|
||||
uniques, inverse = xp.unique_inverse(values)
|
||||
elif return_counts:
|
||||
uniques, counts = xp.unique_counts(values)
|
||||
else:
|
||||
uniques = xp.unique_values(values)
|
||||
|
||||
# np.unique will have duplicate missing values at the end of `uniques`
|
||||
# here we clip the nans and remove it from uniques
|
||||
if uniques.size and is_scalar_nan(uniques[-1]):
|
||||
nan_idx = _searchsorted(uniques, xp.nan, xp=xp)
|
||||
uniques = uniques[: nan_idx + 1]
|
||||
if return_inverse:
|
||||
inverse[inverse > nan_idx] = nan_idx
|
||||
|
||||
if return_counts:
|
||||
counts[nan_idx] = xp.sum(counts[nan_idx:])
|
||||
counts = counts[: nan_idx + 1]
|
||||
|
||||
ret = (uniques,)
|
||||
|
||||
if return_inverse:
|
||||
ret += (inverse,)
|
||||
|
||||
if return_counts:
|
||||
ret += (counts,)
|
||||
|
||||
return ret[0] if len(ret) == 1 else ret
|
||||
|
||||
|
||||
class MissingValues(NamedTuple):
|
||||
"""Data class for missing data information"""
|
||||
|
||||
nan: bool
|
||||
none: bool
|
||||
|
||||
def to_list(self):
|
||||
"""Convert tuple to a list where None is always first."""
|
||||
output = []
|
||||
if self.none:
|
||||
output.append(None)
|
||||
if self.nan:
|
||||
output.append(np.nan)
|
||||
return output
|
||||
|
||||
|
||||
def _extract_missing(values):
|
||||
"""Extract missing values from `values`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
values: set
|
||||
Set of values to extract missing from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: set
|
||||
Set with missing values extracted.
|
||||
|
||||
missing_values: MissingValues
|
||||
Object with missing value information.
|
||||
"""
|
||||
missing_values_set = {
|
||||
value for value in values if value is None or is_scalar_nan(value)
|
||||
}
|
||||
|
||||
if not missing_values_set:
|
||||
return values, MissingValues(nan=False, none=False)
|
||||
|
||||
if None in missing_values_set:
|
||||
if len(missing_values_set) == 1:
|
||||
output_missing_values = MissingValues(nan=False, none=True)
|
||||
else:
|
||||
# If there is more than one missing value, then it has to be
|
||||
# float('nan') or np.nan
|
||||
output_missing_values = MissingValues(nan=True, none=True)
|
||||
else:
|
||||
output_missing_values = MissingValues(nan=True, none=False)
|
||||
|
||||
# create set without the missing values
|
||||
output = values - missing_values_set
|
||||
return output, output_missing_values
|
||||
|
||||
|
||||
class _nandict(dict):
    """Dictionary with support for nans."""

    def __init__(self, mapping):
        super().__init__(mapping)
        # Cache the value of the first NaN key: NaN keys cannot be found by
        # regular dict lookup because nan != nan.
        for k, v in mapping.items():
            if is_scalar_nan(k):
                self.nan_value = v
                break

    def __missing__(self, key):
        # Only NaN keys get the fallback; anything else is a genuine miss.
        if is_scalar_nan(key) and hasattr(self, "nan_value"):
            return self.nan_value
        raise KeyError(key)
|
||||
|
||||
|
||||
def _map_to_integer(values, uniques):
    """Map each entry of `values` to its position in `uniques`."""
    xp, _ = get_namespace(values, uniques)
    # _nandict makes NaN entries findable, unlike a plain dict.
    positions = {item: idx for idx, item in enumerate(uniques)}
    lookup = _nandict(positions)
    encoded = [lookup[item] for item in values]
    return xp.asarray(encoded, device=device(values))
|
||||
|
||||
|
||||
def _unique_python(values, *, return_inverse, return_counts):
    # Only used in `_uniques`, see docstring there for details
    try:
        uniques_set, missing_values = _extract_missing(set(values))
        # Regular values are sorted; missing markers always go last.
        sorted_uniques = sorted(uniques_set)
        sorted_uniques.extend(missing_values.to_list())
        uniques = np.array(sorted_uniques, dtype=values.dtype)
    except TypeError:
        # Mixed str/number inputs are not orderable; report the offending types.
        types = sorted(t.__qualname__ for t in set(type(v) for v in values))
        raise TypeError(
            "Encoders require their input argument must be uniformly "
            f"strings or numbers. Got {types}"
        )

    ret = (uniques,)
    if return_inverse:
        ret += (_map_to_integer(values, uniques),)
    if return_counts:
        ret += (_get_counts(values, uniques),)

    return ret if len(ret) > 1 else ret[0]
|
||||
|
||||
|
||||
def _encode(values, *, uniques, check_unknown=True):
    """Helper function to encode values into [0, n_uniques - 1].

    Object/string dtypes go through a pure-python mapping; all other dtypes
    use a numpy binary search, which requires `uniques` to be sorted. The
    sortedness is assumed, not checked — callers must guarantee it for
    non-object values.

    Parameters
    ----------
    values : ndarray
        Values to encode.
    uniques : ndarray
        The unique values in `values`. If the dtype is not object, then
        `uniques` needs to be sorted.
    check_unknown : bool, default=True
        If True, check for values in `values` that are not in `unique`
        and raise an error. This is ignored for object dtype, and treated as
        True in this case. This parameter is useful for
        _BaseEncoder._transform() to avoid calling _check_unknown()
        twice.

    Returns
    -------
    encoded : ndarray
        Encoded values
    """
    xp, _ = get_namespace(values, uniques)

    if xp.isdtype(values.dtype, "numeric"):
        if check_unknown:
            diff = _check_unknown(values, uniques)
            if diff:
                raise ValueError(f"y contains previously unseen labels: {diff}")
        return _searchsorted(uniques, values, xp=xp)

    # Object/string dtype: the mapping itself detects unknown values, so
    # `check_unknown` is effectively always True here.
    try:
        return _map_to_integer(values, uniques)
    except KeyError as e:
        raise ValueError(f"y contains previously unseen labels: {e}")
|
||||
|
||||
|
||||
def _check_unknown(values, known_values, return_mask=False):
    """
    Helper function to check for unknowns in values to be encoded.

    Uses pure python method for object dtype, and numpy method for
    all other dtypes.

    Parameters
    ----------
    values : array
        Values to check for unknowns.
    known_values : array
        Known values. Must be unique.
    return_mask : bool, default=False
        If True, return a mask of the same shape as `values` indicating
        the valid values.

    Returns
    -------
    diff : list
        The unique values present in `values` and not in `known_values`.
    valid_mask : boolean array
        Additionally returned if ``return_mask=True``.

    """
    xp, _ = get_namespace(values, known_values)
    valid_mask = None

    if not xp.isdtype(values.dtype, "numeric"):
        # Object/string path: work on Python sets, tracking None/NaN markers
        # separately since they do not behave in set membership tests.
        values_set = set(values)
        values_set, missing_in_values = _extract_missing(values_set)

        uniques_set = set(known_values)
        uniques_set, missing_in_uniques = _extract_missing(uniques_set)
        diff = values_set - uniques_set

        # A missing marker only counts as unknown if it appears in `values`
        # but not in `known_values`.
        nan_in_diff = missing_in_values.nan and not missing_in_uniques.nan
        none_in_diff = missing_in_values.none and not missing_in_uniques.none

        def is_valid(value):
            # Known if it is a regular known value, or a missing marker that
            # also occurs among the known values.
            return (
                value in uniques_set
                or (missing_in_uniques.none and value is None)
                or (missing_in_uniques.nan and is_scalar_nan(value))
            )

        if return_mask:
            if diff or nan_in_diff or none_in_diff:
                valid_mask = xp.array([is_valid(value) for value in values])
            else:
                # No unknowns at all: everything is valid.
                valid_mask = xp.ones(len(values), dtype=xp.bool)

        diff = list(diff)
        if none_in_diff:
            diff.append(None)
        if nan_in_diff:
            diff.append(np.nan)
    else:
        # Numeric path: vectorized set difference.
        unique_values = xp.unique_values(values)
        diff = xpx.setdiff1d(unique_values, known_values, assume_unique=True, xp=xp)
        if return_mask:
            if diff.size:
                valid_mask = _isin(values, known_values, xp)
            else:
                valid_mask = xp.ones(len(values), dtype=xp.bool)

        # check for nans in the known_values
        if xp.any(xp.isnan(known_values)):
            diff_is_nan = xp.isnan(diff)
            if xp.any(diff_is_nan):
                # NaN is known: mark NaN positions as valid in the mask
                # (setdiff1d/_isin treat NaN as unknown because nan != nan).
                if diff.size and return_mask:
                    is_nan = xp.isnan(values)
                    valid_mask[is_nan] = 1

                # remove nan from diff
                diff = diff[~diff_is_nan]
        diff = list(diff)

    if return_mask:
        return diff, valid_mask
    return diff
|
||||
|
||||
|
||||
class _NaNCounter(Counter):
    """Counter with support for nan values."""

    def __init__(self, items):
        super().__init__(self._generate_items(items))

    def _generate_items(self, items):
        """Yield non-NaN items; accumulate NaN occurrences on the side.

        Counter cannot key on NaN reliably because nan != nan, so all NaN
        objects are folded into a single ``nan_count`` attribute.
        """
        for item in items:
            if is_scalar_nan(item):
                self.nan_count = getattr(self, "nan_count", 0) + 1
            else:
                yield item

    def __missing__(self, key):
        if is_scalar_nan(key) and hasattr(self, "nan_count"):
            return self.nan_count
        raise KeyError(key)
|
||||
|
||||
|
||||
def _get_counts(values, uniques):
    """Get the count of each of the `uniques` in `values`.

    The counts will use the order passed in by `uniques`. For non-object dtypes,
    `uniques` is assumed to be sorted and `np.nan` is at the end.
    """
    if values.dtype.kind in "OU":
        # Object/unicode dtype: count in pure Python. _NaNCounter folds all
        # NaN objects into a single count.
        counter = _NaNCounter(values)
        output = np.zeros(len(uniques), dtype=np.int64)
        for i, item in enumerate(uniques):
            # Items absent from `values` keep their zero count.
            with suppress(KeyError):
                output[i] = counter[item]
        return output

    unique_values, counts = _unique_np(values, return_counts=True)

    # Reorder unique_values based on input: `uniques`
    uniques_in_values = np.isin(uniques, unique_values, assume_unique=True)
    # np.isin cannot match NaN against NaN (nan != nan); both arrays keep
    # NaN at the end, so patch the last entry by hand.
    if np.isnan(unique_values[-1]) and np.isnan(uniques[-1]):
        uniques_in_values[-1] = True

    unique_valid_indices = np.searchsorted(unique_values, uniques[uniques_in_values])
    output = np.zeros_like(uniques, dtype=np.int64)
    output[uniques_in_values] = counts[unique_valid_indices]
    return output
|
||||
@@ -0,0 +1,34 @@
|
||||
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

# Deprecated shim module: re-exports the estimator HTML-repr helpers from
# their new private location (`sklearn.utils._repr_html`) so that old imports
# keep working until this module is removed.

import warnings

from ._repr_html.base import _HTMLDocumentationLinkMixin
from ._repr_html.estimator import (
    _get_visual_block,
    _IDCounter,
    _VisualBlock,
    _write_estimator_html,
    _write_label_html,
    estimator_html_repr,
)

__all__ = [
    "_HTMLDocumentationLinkMixin",
    "_IDCounter",
    "_VisualBlock",
    "_get_visual_block",
    "_write_estimator_html",
    "_write_label_html",
    "estimator_html_repr",
]

# TODO(1.8): Remove the entire module
# Importing this module at all triggers the deprecation warning.
warnings.warn(
    "Importing from sklearn.utils._estimator_html_repr is deprecated. The tools have "
    "been moved to sklearn.utils._repr_html. Be aware that this module is private and "
    "may be subject to change in the future. The module _estimator_html_repr will be "
    "removed in 1.8.0.",
    FutureWarning,
    stacklevel=2,
)
|
||||
Binary file not shown.
@@ -0,0 +1,19 @@
|
||||
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

"""
Uses C++ map containers for fast dict-like behavior with keys being
integers, and values float.
"""

from libcpp.map cimport map as cpp_map

from ._typedefs cimport float64_t, intp_t


###############################################################################
# An object to be used in Python

cdef class IntFloatDict:
    # Underlying C++ std::map holding the int -> float entries.
    cdef cpp_map[intp_t, float64_t] my_map
    # Fill caller-allocated `keys`/`values` arrays with the map contents.
    cdef _to_arrays(self, intp_t [:] keys, float64_t [:] values)
|
||||
137
venv/lib/python3.12/site-packages/sklearn/utils/_fast_dict.pyx
Normal file
137
venv/lib/python3.12/site-packages/sklearn/utils/_fast_dict.pyx
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Uses C++ map containers for fast dict-like behavior with keys being
|
||||
integers, and values float.
|
||||
"""
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# C++
|
||||
from cython.operator cimport dereference as deref, preincrement as inc
|
||||
from libcpp.utility cimport pair
|
||||
from libcpp.map cimport map as cpp_map
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._typedefs cimport float64_t, intp_t
|
||||
|
||||
|
||||
###############################################################################
|
||||
# An object to be used in Python
|
||||
|
||||
# Lookup is faster than dict (up to 10 times), and so is full traversal
|
||||
# (up to 50 times), and assignment (up to 6 times), but creation is
|
||||
# slower (up to 3 times). Also, a large benefit is that memory
|
||||
# consumption is reduced a lot compared to a Python dict
|
||||
|
||||
cdef class IntFloatDict:
    """Dict-like wrapper around a C++ std::map[intp_t, float64_t]."""

    def __init__(
        self,
        intp_t[:] keys,
        float64_t[:] values,
    ):
        # Populate the map pairwise from the two parallel arrays.
        cdef int i
        cdef int size = values.size
        # Should check that sizes for keys and values are equal, and
        # after should boundcheck(False)
        for i in range(size):
            self.my_map[keys[i]] = values[i]

    def __len__(self):
        return self.my_map.size()

    def __getitem__(self, int key):
        # O(log n) lookup in the ordered std::map.
        cdef cpp_map[intp_t, float64_t].iterator it = self.my_map.find(key)
        if it == self.my_map.end():
            # The key is not in the dict
            raise KeyError('%i' % key)
        return deref(it).second

    def __setitem__(self, int key, float value):
        self.my_map[key] = value

    # Cython 0.20 generates buggy code below. Commenting this out for now
    # and relying on the to_arrays method
    # def __iter__(self):
    #     cdef cpp_map[intp_t, float64_t].iterator it = self.my_map.begin()
    #     cdef cpp_map[intp_t, float64_t].iterator end = self.my_map.end()
    #     while it != end:
    #         yield deref(it).first, deref(it).second
    #         inc(it)

    def __iter__(self):
        # Snapshot the map into arrays first, then yield from the snapshot
        # (works around the buggy direct-iteration code commented out above).
        cdef int size = self.my_map.size()
        cdef intp_t [:] keys = np.empty(size, dtype=np.intp)
        cdef float64_t [:] values = np.empty(size, dtype=np.float64)
        self._to_arrays(keys, values)
        cdef int idx
        cdef intp_t key
        cdef float64_t value
        for idx in range(size):
            key = keys[idx]
            value = values[idx]
            yield key, value

    def to_arrays(self):
        """Return the key, value representation of the IntFloatDict
        object.

        Returns
        =======
        keys : ndarray, shape (n_items, ), dtype=int
            The indices of the data points
        values : ndarray, shape (n_items, ), dtype=float
            The values of the data points
        """
        cdef int size = self.my_map.size()
        keys = np.empty(size, dtype=np.intp)
        values = np.empty(size, dtype=np.float64)
        self._to_arrays(keys, values)
        return keys, values

    cdef _to_arrays(self, intp_t [:] keys, float64_t [:] values):
        # Internal version of to_arrays that takes already-initialized arrays
        cdef cpp_map[intp_t, float64_t].iterator it = self.my_map.begin()
        cdef cpp_map[intp_t, float64_t].iterator end = self.my_map.end()
        cdef int index = 0
        while it != end:
            keys[index] = deref(it).first
            values[index] = deref(it).second
            inc(it)
            index += 1

    def update(self, IntFloatDict other):
        # Merge `other` into self; existing keys are overwritten.
        cdef cpp_map[intp_t, float64_t].iterator it = other.my_map.begin()
        cdef cpp_map[intp_t, float64_t].iterator end = other.my_map.end()
        while it != end:
            self.my_map[deref(it).first] = deref(it).second
            inc(it)

    def copy(self):
        cdef IntFloatDict out_obj = IntFloatDict.__new__(IntFloatDict)
        # The '=' operator is a copy operator for C++ maps
        out_obj.my_map = self.my_map
        return out_obj

    def append(self, intp_t key, float64_t value):
        # Construct our arguments
        cdef pair[intp_t, float64_t] args
        args.first = key
        args.second = value
        # insert() keeps the existing value if the key is already present.
        self.my_map.insert(args)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# operation on dict
|
||||
|
||||
def argmin(IntFloatDict d):
    """Return the (key, value) pair with the smallest value in `d`.

    Returns (-1, np.inf) when the dict is empty.
    """
    cdef cpp_map[intp_t, float64_t].iterator it = d.my_map.begin()
    cdef cpp_map[intp_t, float64_t].iterator end = d.my_map.end()
    cdef intp_t min_key = -1
    cdef float64_t min_value = np.inf
    # Linear scan over the map entries, tracking the running minimum.
    while it != end:
        if deref(it).second < min_value:
            min_value = deref(it).second
            min_key = deref(it).first
        inc(it)
    return min_key, min_value
|
||||
Binary file not shown.
14
venv/lib/python3.12/site-packages/sklearn/utils/_heap.pxd
Normal file
14
venv/lib/python3.12/site-packages/sklearn/utils/_heap.pxd
Normal file
@@ -0,0 +1,14 @@
|
||||
# Heap routines, used in various Cython implementations.
|
||||
|
||||
from cython cimport floating
|
||||
|
||||
from ._typedefs cimport intp_t
|
||||
|
||||
|
||||
# Push (val, val_idx) onto the fixed-size max-heap stored in the parallel
# `values`/`indices` arrays; see the .pyx implementation for details.
cdef int heap_push(
    floating* values,
    intp_t* indices,
    intp_t size,
    floating val,
    intp_t val_idx,
) noexcept nogil
|
||||
85
venv/lib/python3.12/site-packages/sklearn/utils/_heap.pyx
Normal file
85
venv/lib/python3.12/site-packages/sklearn/utils/_heap.pyx
Normal file
@@ -0,0 +1,85 @@
|
||||
from cython cimport floating
|
||||
|
||||
from ._typedefs cimport intp_t
|
||||
|
||||
|
||||
cdef inline int heap_push(
    floating* values,
    intp_t* indices,
    intp_t size,
    floating val,
    intp_t val_idx,
) noexcept nogil:
    """Push a tuple (val, val_idx) onto a fixed-size max-heap.

    The max-heap is represented as a Structure of Arrays where:
    - values is the array containing the data to construct the heap with
    - indices is the array containing the indices (meta-data) of each value

    Notes
    -----
    Arrays are manipulated via a pointer to their first element and their size
    as to ease the processing of dynamically allocated buffers.

    For instance, in pseudo-code:

        values = [1.2, 0.4, 0.1],
        indices = [42, 1, 5],
        heap_push(
            values=values,
            indices=indices,
            size=3,
            val=0.2,
            val_idx=4,
        )

    will modify values and indices inplace, giving at the end of the call:

        values == [0.4, 0.2, 0.1]
        indices == [1, 4, 5]

    """
    cdef:
        intp_t current_idx, left_child_idx, right_child_idx, swap_idx

    # Check if val should be in heap
    if val >= values[0]:
        return 0

    # Insert val at position zero
    values[0] = val
    indices[0] = val_idx

    # Descend the heap, swapping values until the max heap criterion is met
    current_idx = 0
    while True:
        left_child_idx = 2 * current_idx + 1
        right_child_idx = left_child_idx + 1

        if left_child_idx >= size:
            # Leaf reached: nothing below to compare against.
            break
        elif right_child_idx >= size:
            # Only a left child exists.
            if values[left_child_idx] > val:
                swap_idx = left_child_idx
            else:
                break
        elif values[left_child_idx] >= values[right_child_idx]:
            # Left child is the larger of the two; swap with it if needed.
            if val < values[left_child_idx]:
                swap_idx = left_child_idx
            else:
                break
        else:
            # Right child is the larger of the two.
            if val < values[right_child_idx]:
                swap_idx = right_child_idx
            else:
                break

        # Pull the larger child up and continue descending.
        values[current_idx] = values[swap_idx]
        indices[current_idx] = indices[swap_idx]

        current_idx = swap_idx

    # Place val at its final position.
    values[current_idx] = val
    indices[current_idx] = val_idx

    return 0
|
||||
755
venv/lib/python3.12/site-packages/sklearn/utils/_indexing.py
Normal file
755
venv/lib/python3.12/site-packages/sklearn/utils/_indexing.py
Normal file
@@ -0,0 +1,755 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import numbers
|
||||
import sys
|
||||
import warnings
|
||||
from collections import UserList
|
||||
from itertools import compress, islice
|
||||
|
||||
import numpy as np
|
||||
from scipy.sparse import issparse
|
||||
|
||||
from sklearn.utils.fixes import PYARROW_VERSION_BELOW_17
|
||||
|
||||
from ._array_api import _is_numpy_namespace, get_namespace
|
||||
from ._param_validation import Interval, validate_params
|
||||
from .extmath import _approximate_mode
|
||||
from .validation import (
|
||||
_check_sample_weight,
|
||||
_is_arraylike_not_scalar,
|
||||
_is_pandas_df,
|
||||
_is_polars_df_or_series,
|
||||
_is_pyarrow_data,
|
||||
_use_interchange_protocol,
|
||||
check_array,
|
||||
check_consistent_length,
|
||||
check_random_state,
|
||||
)
|
||||
|
||||
|
||||
def _array_indexing(array, key, key_dtype, axis):
    """Index an array or scipy.sparse consistently across NumPy version."""
    xp, is_array_api = get_namespace(array)
    if is_array_api:
        return xp.take(array, key, axis=axis)
    if issparse(array) and key_dtype == "bool":
        # Sparse matrices need a materialized boolean array, not a mask-like.
        key = np.asarray(key)
    if isinstance(key, tuple):
        # A tuple would be read as multi-dimensional indexing; use a list.
        key = list(key)
    if axis == 0:
        return array[key, ...]
    return array[:, key]
|
||||
|
||||
|
||||
def _pandas_indexing(X, key, key_dtype, axis):
    """Index a pandas dataframe or a series."""
    if _is_arraylike_not_scalar(key):
        key = np.asarray(key)

    is_positional = key_dtype == "int"
    if is_positional and not (isinstance(key, slice) or np.isscalar(key)):
        # take() returns a "proper" copy that will not raise
        # SettingWithCopyWarning, unlike iloc[].
        return X.take(key, axis=axis)

    # iloc for positional scalars/slices, loc for label-based keys.
    indexer = X.iloc if is_positional else X.loc
    return indexer[:, key] if axis else indexer[key]
|
||||
|
||||
|
||||
def _list_indexing(X, key, key_dtype):
|
||||
"""Index a Python list."""
|
||||
if np.isscalar(key) or isinstance(key, slice):
|
||||
# key is a slice or a scalar
|
||||
return X[key]
|
||||
if key_dtype == "bool":
|
||||
# key is a boolean array-like
|
||||
return list(compress(X, key))
|
||||
# key is a integer array-like of key
|
||||
return [X[idx] for idx in key]
|
||||
|
||||
|
||||
def _polars_indexing(X, key, key_dtype, axis):
    """Index a polars dataframe or series."""
    # Normalize the key first: polars indexing behaves most consistently
    # with plain Python lists.
    if isinstance(key, np.ndarray):
        # tolist() converts every element to a Python scalar.
        key = key.tolist()
    elif not (np.isscalar(key) or isinstance(key, slice)):
        key = list(key)

    if axis == 1:
        # Column selection: X is necessarily a polars DataFrame here, which
        # accepts scalar int/str and lists of int, str and bool.
        return X[:, key]

    # axis == 0 from here on.
    if key_dtype == "bool":
        # Series and DataFrame share the same filter API for boolean masks.
        return X.filter(key)

    subset = X[key]
    if np.isscalar(key) and len(X.shape) == 2:
        # A scalar row selection on a DataFrame yields a one-row frame;
        # convert it to a Series for consistency with pandas.
        pl = sys.modules["polars"]
        return pl.Series(subset.row(0))
    return subset
|
||||
|
||||
|
||||
def _pyarrow_indexing(X, key, key_dtype, axis):
    """Index pyarrow data (Table, RecordBatch, or Array/ChunkedArray).

    Rows (axis=0) or columns (axis=1) are selected according to `key`,
    whose kind is given by `key_dtype` ('int', 'str' or 'bool').
    """
    scalar_key = np.isscalar(key)
    if isinstance(key, slice):
        # Expand the slice to an explicit list of positions; string slice
        # bounds are resolved through the column names (endpoint included).
        if isinstance(key.stop, str):
            start = X.column_names.index(key.start)
            stop = X.column_names.index(key.stop) + 1
        else:
            start = 0 if not key.start else key.start
            stop = key.stop
        step = 1 if not key.step else key.step
        key = list(range(start, stop, step))

    if axis == 1:
        # Here we are certain that X is a pyarrow Table or RecordBatch.
        if key_dtype == "int" and not isinstance(key, list):
            # pyarrow's X.select behavior is more consistent with integer lists.
            key = np.asarray(key).tolist()
        if key_dtype == "bool":
            # select() takes positions, so convert the mask to indices.
            key = np.asarray(key).nonzero()[0].tolist()

        if scalar_key:
            return X.column(key)

        return X.select(key)

    # axis == 0 from here on
    if scalar_key:
        if hasattr(X, "shape"):
            # X is a Table or RecordBatch
            key = [key]
        else:
            # X is an Array/ChunkedArray: return the element as a Python value.
            return X[key].as_py()
    elif not isinstance(key, list):
        key = np.asarray(key)

    if key_dtype == "bool":
        # TODO(pyarrow): remove version checking and following if-branch when
        # pyarrow==17.0.0 is the minimal version, see pyarrow issue
        # https://github.com/apache/arrow/issues/42013 for more info
        if PYARROW_VERSION_BELOW_17:
            import pyarrow

            if not isinstance(key, pyarrow.BooleanArray):
                key = pyarrow.array(key, type=pyarrow.bool_())

        X_indexed = X.filter(key)

    else:
        X_indexed = X.take(key)

    if scalar_key and len(getattr(X, "shape", [0])) == 2:
        # X_indexed is a dataframe-like with a single row; we return a Series to be
        # consistent with pandas
        pa = sys.modules["pyarrow"]
        return pa.array(X_indexed.to_pylist()[0].values())
    return X_indexed
|
||||
|
||||
|
||||
def _determine_key_type(key, accept_slice=True):
    """Determine the data type of key.

    Parameters
    ----------
    key : scalar, slice or array-like
        The key from which we want to infer the data type.

    accept_slice : bool, default=True
        Whether or not to raise an error if the key is a slice.

    Returns
    -------
    dtype : {'int', 'str', 'bool', None}
        Returns the data type of key.
    """
    err_msg = (
        "No valid specification of the columns. Only a scalar, list or "
        "slice of all integers or all strings, or boolean mask is "
        "allowed"
    )

    # Maps from Python scalar types / ndarray dtype kind codes to the
    # canonical key-type strings returned by this function.
    dtype_to_str = {int: "int", str: "str", bool: "bool", np.bool_: "bool"}
    array_dtype_to_str = {
        "i": "int",
        "u": "int",
        "b": "bool",
        "O": "str",
        "U": "str",
        "S": "str",
    }

    if key is None:
        return None
    if isinstance(key, tuple(dtype_to_str.keys())):
        try:
            return dtype_to_str[type(key)]
        except KeyError:
            raise ValueError(err_msg)
    if isinstance(key, slice):
        if not accept_slice:
            raise TypeError(
                "Only array-like or scalar are supported. A Python slice was given."
            )
        if key.start is None and key.stop is None:
            return None
        # Infer the type from each endpoint; both must agree when present.
        key_start_type = _determine_key_type(key.start)
        key_stop_type = _determine_key_type(key.stop)
        if key_start_type is not None and key_stop_type is not None:
            if key_start_type != key_stop_type:
                raise ValueError(err_msg)
        if key_start_type is not None:
            return key_start_type
        return key_stop_type
    # TODO(1.9) remove UserList when the force_int_remainder_cols param
    # of ColumnTransformer is removed
    if isinstance(key, (list, tuple, UserList)):
        unique_key = set(key)
        key_type = {_determine_key_type(elt) for elt in unique_key}
        if not key_type:
            # Empty container: no type to infer.
            return None
        if len(key_type) != 1:
            # Mixed element types are not a valid column specification.
            raise ValueError(err_msg)
        return key_type.pop()
    if hasattr(key, "dtype"):
        xp, is_array_api = get_namespace(key)
        # NumPy arrays are special-cased in their own branch because the Array API
        # cannot handle object/string-based dtypes that are often used to index
        # columns of dataframes by names.
        if is_array_api and not _is_numpy_namespace(xp):
            if xp.isdtype(key.dtype, "bool"):
                return "bool"
            elif xp.isdtype(key.dtype, "integral"):
                return "int"
            else:
                raise ValueError(err_msg)
        else:
            try:
                return array_dtype_to_str[key.dtype.kind]
            except KeyError:
                raise ValueError(err_msg)
    raise ValueError(err_msg)
|
||||
|
||||
|
||||
def _safe_indexing(X, indices, *, axis=0):
    """Return rows, items or columns of X using indices.

    .. warning::

        This utility is documented, but **private**. This means that
        backward compatibility might be broken without any deprecation
        cycle.

    Parameters
    ----------
    X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series
        Data from which to sample rows, items or columns. `list` are only
        supported when `axis=0`.
    indices : bool, int, str, slice, array-like
        - If `axis=0`, boolean and integer array-like, integer slice,
          and scalar integer are supported.
        - If `axis=1`:
            - to select a single column, `indices` can be of `int` type for
              all `X` types and `str` only for dataframe. The selected subset
              will be 1D, unless `X` is a sparse matrix in which case it will
              be 2D.
            - to select multiples columns, `indices` can be one of the
              following: `list`, `array`, `slice`. The type used in
              these containers can be one of the following: `int`, 'bool' and
              `str`. However, `str` is only supported when `X` is a dataframe.
              The selected subset will be 2D.
    axis : int, default=0
        The axis along which `X` will be subsampled. `axis=0` will select
        rows while `axis=1` will select columns.

    Returns
    -------
    subset
        Subset of X on axis 0 or 1.

    Notes
    -----
    CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are
    not supported.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils import _safe_indexing
    >>> data = np.array([[1, 2], [3, 4], [5, 6]])
    >>> _safe_indexing(data, 0, axis=0)  # select the first row
    array([1, 2])
    >>> _safe_indexing(data, 0, axis=1)  # select the first column
    array([1, 3, 5])
    """
    if indices is None:
        return X

    if axis not in (0, 1):
        raise ValueError(
            "'axis' should be either 0 (to index rows) or 1 (to index "
            " column). Got {} instead.".format(axis)
        )

    indices_dtype = _determine_key_type(indices)

    if axis == 0 and indices_dtype == "str":
        raise ValueError("String indexing is not supported with 'axis=0'")

    if axis == 1 and isinstance(X, list):
        raise ValueError("axis=1 is not supported for lists")

    if axis == 1 and (ndim := len(getattr(X, "shape", [0]))) != 2:
        raise ValueError(
            "'X' should be a 2D NumPy array, 2D sparse matrix or "
            "dataframe when indexing the columns (i.e. 'axis=1'). "
            f"Got {type(X)} instead with {ndim} dimension(s)."
        )

    if (
        axis == 1
        and indices_dtype == "str"
        and not (_is_pandas_df(X) or _use_interchange_protocol(X))
    ):
        raise ValueError(
            "Specifying the columns using strings is only supported for dataframes."
        )

    # Dispatch to the container-specific indexing helper.
    if hasattr(X, "iloc"):
        # TODO: we should probably use _is_pandas_df_or_series(X) instead but:
        # 1) Currently, it (probably) works for dataframes compliant to pandas' API.
        # 2) Updating would require updating some tests such as
        #    test_train_test_split_mock_pandas.
        return _pandas_indexing(X, indices, indices_dtype, axis=axis)
    elif _is_polars_df_or_series(X):
        return _polars_indexing(X, indices, indices_dtype, axis=axis)
    elif _is_pyarrow_data(X):
        return _pyarrow_indexing(X, indices, indices_dtype, axis=axis)
    elif _use_interchange_protocol(X):  # pragma: no cover
        # Once the dataframe X is converted into its dataframe interchange protocol
        # version by calling X.__dataframe__(), it becomes very hard to turn it back
        # into its original type, e.g., a pyarrow.Table, see
        # https://github.com/data-apis/dataframe-api/issues/85.
        # BUG FIX: this used to be `raise warnings.warn(...)`. `warnings.warn`
        # returns None, so `raise` would throw
        # "TypeError: exceptions must derive from BaseException" instead of
        # warning and falling through to array/list indexing as the message
        # promises. Emit the warning and continue.
        warnings.warn(
            message="A data object with support for the dataframe interchange protocol"
            "was passed, but scikit-learn does currently not know how to handle this "
            "kind of data. Some array/list indexing will be tried.",
            category=UserWarning,
        )

    if hasattr(X, "shape"):
        return _array_indexing(X, indices, indices_dtype, axis=axis)
    else:
        return _list_indexing(X, indices, indices_dtype)
|
||||
|
||||
|
||||
def _safe_assign(X, values, *, row_indexer=None, column_indexer=None):
|
||||
"""Safe assignment to a numpy array, sparse matrix, or pandas dataframe.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : {ndarray, sparse-matrix, dataframe}
|
||||
Array to be modified. It is expected to be 2-dimensional.
|
||||
|
||||
values : ndarray
|
||||
The values to be assigned to `X`.
|
||||
|
||||
row_indexer : array-like, dtype={int, bool}, default=None
|
||||
A 1-dimensional array to select the rows of interest. If `None`, all
|
||||
rows are selected.
|
||||
|
||||
column_indexer : array-like, dtype={int, bool}, default=None
|
||||
A 1-dimensional array to select the columns of interest. If `None`, all
|
||||
columns are selected.
|
||||
"""
|
||||
row_indexer = slice(None, None, None) if row_indexer is None else row_indexer
|
||||
column_indexer = (
|
||||
slice(None, None, None) if column_indexer is None else column_indexer
|
||||
)
|
||||
|
||||
if hasattr(X, "iloc"): # pandas dataframe
|
||||
with warnings.catch_warnings():
|
||||
# pandas >= 1.5 raises a warning when using iloc to set values in a column
|
||||
# that does not have the same type as the column being set. It happens
|
||||
# for instance when setting a categorical column with a string.
|
||||
# In the future the behavior won't change and the warning should disappear.
|
||||
# TODO(1.3): check if the warning is still raised or remove the filter.
|
||||
warnings.simplefilter("ignore", FutureWarning)
|
||||
X.iloc[row_indexer, column_indexer] = values
|
||||
else: # numpy array or sparse matrix
|
||||
X[row_indexer, column_indexer] = values
|
||||
|
||||
|
||||
def _get_column_indices_for_bool_or_int(key, n_columns):
    """Convert a boolean or integer `key` into a list of positive column indices.

    Delegates the indexing semantics (boolean masks, negative integers,
    integer arrays) to `_safe_indexing` applied to the positions
    ``0..n_columns - 1``, so out-of-range keys surface as `ValueError`.
    """
    positions = np.arange(n_columns)
    try:
        selected = _safe_indexing(positions, key)
    except IndexError as exc:
        raise ValueError(
            f"all features must be in [0, {n_columns - 1}] or [-{n_columns}, 0]"
        ) from exc
    # A scalar key yields a 0-d result; atleast_1d normalizes to a list.
    return np.atleast_1d(selected).tolist()
|
||||
|
||||
|
||||
def _get_column_indices(X, key):
    """Get feature column indices for input data X and key.

    For accepted values of `key`, see the docstring of
    :func:`_safe_indexing`.
    """
    key_dtype = _determine_key_type(key)
    if _use_interchange_protocol(X):
        # Dataframe-interchange objects expose columns through __dataframe__
        # rather than pandas-style attributes; handle them separately.
        return _get_column_indices_interchange(X.__dataframe__(), key, key_dtype)

    n_columns = X.shape[1]
    if isinstance(key, (list, tuple)) and not key:
        # we get an empty list
        return []
    elif key_dtype in ("bool", "int"):
        return _get_column_indices_for_bool_or_int(key, n_columns)
    else:
        # String-based selection is only meaningful for dataframes.
        try:
            all_columns = X.columns
        except AttributeError:
            raise ValueError(
                "Specifying the columns using strings is only supported for dataframes."
            )
        if isinstance(key, str):
            columns = [key]
        elif isinstance(key, slice):
            # Translate a slice of column labels into positional bounds.
            start, stop = key.start, key.stop
            if start is not None:
                start = all_columns.get_loc(start)
            if stop is not None:
                # pandas indexing with strings is endpoint included
                stop = all_columns.get_loc(stop) + 1
            else:
                stop = n_columns + 1
            return list(islice(range(n_columns), start, stop))
        else:
            columns = list(key)

        try:
            column_indices = []
            for col in columns:
                col_idx = all_columns.get_loc(col)
                # get_loc returns a slice or boolean mask (not an integer)
                # when the label occurs more than once.
                if not isinstance(col_idx, numbers.Integral):
                    raise ValueError(
                        f"Selected columns, {columns}, are not unique in dataframe"
                    )
                column_indices.append(col_idx)

        except KeyError as e:
            raise ValueError("A given column is not a column of the dataframe") from e

        return column_indices
|
||||
|
||||
|
||||
def _get_column_indices_interchange(X_interchange, key, key_dtype):
|
||||
"""Same as _get_column_indices but for X with __dataframe__ protocol."""
|
||||
|
||||
n_columns = X_interchange.num_columns()
|
||||
|
||||
if isinstance(key, (list, tuple)) and not key:
|
||||
# we get an empty list
|
||||
return []
|
||||
elif key_dtype in ("bool", "int"):
|
||||
return _get_column_indices_for_bool_or_int(key, n_columns)
|
||||
else:
|
||||
column_names = list(X_interchange.column_names())
|
||||
|
||||
if isinstance(key, slice):
|
||||
if key.step not in [1, None]:
|
||||
raise NotImplementedError("key.step must be 1 or None")
|
||||
start, stop = key.start, key.stop
|
||||
if start is not None:
|
||||
start = column_names.index(start)
|
||||
|
||||
if stop is not None:
|
||||
stop = column_names.index(stop) + 1
|
||||
else:
|
||||
stop = n_columns + 1
|
||||
return list(islice(range(n_columns), start, stop))
|
||||
|
||||
selected_columns = [key] if np.isscalar(key) else key
|
||||
|
||||
try:
|
||||
return [column_names.index(col) for col in selected_columns]
|
||||
except ValueError as e:
|
||||
raise ValueError("A given column is not a column of the dataframe") from e
|
||||
|
||||
|
||||
@validate_params(
    {
        "replace": ["boolean"],
        "n_samples": [Interval(numbers.Integral, 1, None, closed="left"), None],
        "random_state": ["random_state"],
        "stratify": ["array-like", "sparse matrix", None],
        "sample_weight": ["array-like", None],
    },
    prefer_skip_nested_validation=True,
)
def resample(
    *arrays,
    replace=True,
    n_samples=None,
    random_state=None,
    stratify=None,
    sample_weight=None,
):
    """Resample arrays or sparse matrices in a consistent way.

    The default strategy implements one step of the bootstrapping
    procedure.

    Parameters
    ----------
    *arrays : sequence of array-like of shape (n_samples,) or \
            (n_samples, n_outputs)
        Indexable data-structures can be arrays, lists, dataframes or scipy
        sparse matrices with consistent first dimension.

    replace : bool, default=True
        Implements resampling with replacement. It must be set to True
        whenever sampling with non-uniform weights: a few data points with very large
        weights are expected to be sampled several times with probability to preserve
        the distribution induced by the weights. If False, this will implement
        (sliced) random permutations.

    n_samples : int, default=None
        Number of samples to generate. If left to None this is
        automatically set to the first dimension of the arrays.
        If replace is False it should not be larger than the length of
        arrays.

    random_state : int, RandomState instance or None, default=None
        Determines random number generation for shuffling
        the data.
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    stratify : {array-like, sparse matrix} of shape (n_samples,) or \
            (n_samples, n_outputs), default=None
        If not None, data is split in a stratified fashion, using this as
        the class labels.

    sample_weight : array-like of shape (n_samples,), default=None
        Contains weight values to be associated with each sample. Values are
        normalized to sum to one and interpreted as probability for sampling
        each data point.

        .. versionadded:: 1.7

    Returns
    -------
    resampled_arrays : sequence of array-like of shape (n_samples,) or \
            (n_samples, n_outputs)
        Sequence of resampled copies of the collections. The original arrays
        are not impacted.

    See Also
    --------
    shuffle : Shuffle arrays or sparse matrices in a consistent way.

    Examples
    --------
    It is possible to mix sparse and dense arrays in the same run::

        >>> import numpy as np
        >>> X = np.array([[1., 0.], [2., 1.], [0., 0.]])
        >>> y = np.array([0, 1, 2])

        >>> from scipy.sparse import coo_matrix
        >>> X_sparse = coo_matrix(X)

        >>> from sklearn.utils import resample
        >>> X, X_sparse, y = resample(X, X_sparse, y, random_state=0)
        >>> X
        array([[1., 0.],
               [2., 1.],
               [1., 0.]])

        >>> X_sparse
        <Compressed Sparse Row sparse matrix of dtype 'float64'
            with 4 stored elements and shape (3, 2)>

        >>> X_sparse.toarray()
        array([[1., 0.],
               [2., 1.],
               [1., 0.]])

        >>> y
        array([0, 1, 0])

        >>> resample(y, n_samples=2, random_state=0)
        array([0, 1])

    Example using stratification::

        >>> y = [0, 0, 1, 1, 1, 1, 1, 1, 1]
        >>> resample(y, n_samples=5, replace=False, stratify=y,
        ...          random_state=0)
        [1, 1, 1, 0, 1]
    """
    # `n_samples` is rebound below to the input length, so keep the
    # requested output size under a distinct name.
    max_n_samples = n_samples
    random_state = check_random_state(random_state)

    if len(arrays) == 0:
        return None

    first = arrays[0]
    n_samples = first.shape[0] if hasattr(first, "shape") else len(first)

    if max_n_samples is None:
        max_n_samples = n_samples
    elif (max_n_samples > n_samples) and (not replace):
        raise ValueError(
            "Cannot sample %d out of arrays with dim %d when replace is False"
            % (max_n_samples, n_samples)
        )

    check_consistent_length(*arrays)

    if sample_weight is not None and not replace:
        raise NotImplementedError(
            "Resampling with sample_weight is only implemented for replace=True."
        )
    if sample_weight is not None and stratify is not None:
        raise NotImplementedError(
            "Resampling with sample_weight is only implemented for stratify=None."
        )
    if stratify is None:
        if replace:
            if sample_weight is not None:
                sample_weight = _check_sample_weight(
                    sample_weight, first, dtype=np.float64
                )
                # Normalize weights into a probability vector for choice().
                p = sample_weight / sample_weight.sum()
            else:
                p = None
            indices = random_state.choice(
                n_samples,
                size=max_n_samples,
                p=p,
                replace=True,
            )
        else:
            # Without replacement: a sliced random permutation.
            indices = np.arange(n_samples)
            random_state.shuffle(indices)
            indices = indices[:max_n_samples]
    else:
        # Code adapted from StratifiedShuffleSplit()
        y = check_array(stratify, ensure_2d=False, dtype=None)
        if y.ndim == 2:
            # for multi-label y, map each distinct row to a string repr
            # using join because str(row) uses an ellipsis if len(row) > 1000
            y = np.array([" ".join(row.astype("str")) for row in y])

        classes, y_indices = np.unique(y, return_inverse=True)
        n_classes = classes.shape[0]

        class_counts = np.bincount(y_indices)

        # Find the sorted list of instances for each class:
        # (np.unique above performs a sort, so code is O(n logn) already)
        class_indices = np.split(
            np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1]
        )

        # Per-class sample counts approximating the class proportions.
        n_i = _approximate_mode(class_counts, max_n_samples, random_state)

        indices = []

        for i in range(n_classes):
            indices_i = random_state.choice(class_indices[i], n_i[i], replace=replace)
            indices.extend(indices_i)

        # Shuffle so the output is not grouped by class.
        indices = random_state.permutation(indices)

    # convert sparse matrices to CSR for row-based indexing
    arrays = [a.tocsr() if issparse(a) else a for a in arrays]
    resampled_arrays = [_safe_indexing(a, indices) for a in arrays]
    if len(resampled_arrays) == 1:
        # syntactic sugar for the unit argument case
        return resampled_arrays[0]
    else:
        return resampled_arrays
|
||||
|
||||
|
||||
def shuffle(*arrays, random_state=None, n_samples=None):
    """Shuffle arrays or sparse matrices in a consistent way.

    This is a convenience alias to ``resample(*arrays, replace=False)`` to do
    random permutations of the collections.

    Parameters
    ----------
    *arrays : sequence of indexable data-structures
        Indexable data-structures can be arrays, lists, dataframes or scipy
        sparse matrices with consistent first dimension.

    random_state : int, RandomState instance or None, default=None
        Determines random number generation for shuffling
        the data.
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    n_samples : int, default=None
        Number of samples to generate. If left to None this is
        automatically set to the first dimension of the arrays. It should
        not be larger than the length of arrays.

    Returns
    -------
    shuffled_arrays : sequence of indexable data-structures
        Sequence of shuffled copies of the collections. The original arrays
        are not impacted.

    See Also
    --------
    resample : Resample arrays or sparse matrices in a consistent way.

    Examples
    --------
    It is possible to mix sparse and dense arrays in the same run::

        >>> import numpy as np
        >>> X = np.array([[1., 0.], [2., 1.], [0., 0.]])
        >>> y = np.array([0, 1, 2])

        >>> from scipy.sparse import coo_matrix
        >>> X_sparse = coo_matrix(X)

        >>> from sklearn.utils import shuffle
        >>> X, X_sparse, y = shuffle(X, X_sparse, y, random_state=0)
        >>> X
        array([[0., 0.],
               [2., 1.],
               [1., 0.]])

        >>> X_sparse
        <Compressed Sparse Row sparse matrix of dtype 'float64'
            with 3 stored elements and shape (3, 2)>

        >>> X_sparse.toarray()
        array([[0., 0.],
               [2., 1.],
               [1., 0.]])

        >>> y
        array([2, 1, 0])

        >>> shuffle(y, n_samples=2, random_state=0)
        array([0, 1])
    """
    # Shuffling is resampling without replacement over the full (or sliced)
    # index range; delegate everything to resample.
    resample_kwargs = {
        "replace": False,
        "n_samples": n_samples,
        "random_state": random_state,
    }
    return resample(*arrays, **resample_kwargs)
|
||||
Binary file not shown.
@@ -0,0 +1,51 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from libc.math cimport isnan, isinf
|
||||
from cython cimport floating
|
||||
|
||||
|
||||
cpdef enum FiniteStatus:
    # Outcome of a finiteness scan over a float buffer.
    all_finite = 0
    has_nan = 1
    has_infinite = 2
|
||||
|
||||
|
||||
def cy_isfinite(floating[::1] a, bint allow_nan=False):
    # Scan a contiguous 1-D float buffer and return a FiniteStatus code.
    # With allow_nan=True, NaNs are tolerated and only infinities are
    # reported (see _isfinite_allow_nan below).
    cdef FiniteStatus result
    with nogil:
        # The scan is pure C code; release the GIL so other threads can run.
        result = _isfinite(a, allow_nan)
    return result
|
||||
|
||||
|
||||
cdef inline FiniteStatus _isfinite(floating[::1] a, bint allow_nan) noexcept nogil:
    # Dispatch once on allow_nan, outside the hot per-element loop.
    cdef floating* a_ptr = &a[0]
    cdef Py_ssize_t length = len(a)
    if allow_nan:
        return _isfinite_allow_nan(a_ptr, length)
    else:
        return _isfinite_disable_nan(a_ptr, length)
|
||||
|
||||
|
||||
cdef inline FiniteStatus _isfinite_allow_nan(floating* a_ptr,
                                             Py_ssize_t length) noexcept nogil:
    # NaNs are allowed here: report only infinities.
    cdef Py_ssize_t i
    cdef floating v
    for i in range(length):
        v = a_ptr[i]
        if isinf(v):
            return FiniteStatus.has_infinite
    return FiniteStatus.all_finite
|
||||
|
||||
|
||||
cdef inline FiniteStatus _isfinite_disable_nan(floating* a_ptr,
                                               Py_ssize_t length) noexcept nogil:
    # Report the first non-finite value encountered; for a given element,
    # NaN is checked before infinity.
    cdef Py_ssize_t i
    cdef floating v
    for i in range(length):
        v = a_ptr[i]
        if isnan(v):
            return FiniteStatus.has_nan
        elif isinf(v):
            return FiniteStatus.has_infinite
    return FiniteStatus.all_finite
|
||||
181
venv/lib/python3.12/site-packages/sklearn/utils/_mask.py
Normal file
181
venv/lib/python3.12/site-packages/sklearn/utils/_mask.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import numpy as np
|
||||
from scipy import sparse as sp
|
||||
|
||||
from ._missing import is_scalar_nan
|
||||
from ._param_validation import validate_params
|
||||
from .fixes import _object_dtype_isnan
|
||||
|
||||
|
||||
def _get_dense_mask(X, value_to_mask):
    """Boolean mask of the entries of dense array X equal to value_to_mask."""
    with suppress(ImportError, AttributeError):
        # We also suppress `AttributeError` because older versions of pandas do
        # not have `NA`.
        import pandas

        if value_to_mask is pandas.NA:
            return pandas.isna(X)

    if not is_scalar_nan(value_to_mask):
        # Plain element-wise equality for any non-NaN value.
        return X == value_to_mask

    kind = X.dtype.kind
    if kind == "f":
        return np.isnan(X)
    if kind in ("i", "u"):
        # can't have NaNs in integer array.
        return np.zeros(X.shape, dtype=bool)
    # np.isnan does not work on object dtypes.
    return _object_dtype_isnan(X)
|
||||
|
||||
|
||||
def _get_mask(X, value_to_mask):
    """Compute the boolean mask X == value_to_mask.

    Parameters
    ----------
    X : {ndarray, sparse matrix} of shape (n_samples, n_features)
        Input data, where ``n_samples`` is the number of samples and
        ``n_features`` is the number of features.

    value_to_mask : {int, float}
        The value which is to be masked in X.

    Returns
    -------
    X_mask : {ndarray, sparse matrix} of shape (n_samples, n_features)
        Missing mask.
    """
    if not sp.issparse(X):
        # For all cases apart of a sparse input where we need to reconstruct
        # a sparse output
        return _get_dense_mask(X, value_to_mask)

    # Sparse input: mask only the explicitly stored entries, then rebuild a
    # sparse matrix of the same format reusing the original sparsity
    # structure (indices/indptr are copied, not shared).
    Xt = _get_dense_mask(X.data, value_to_mask)

    sparse_constructor = sp.csr_matrix if X.format == "csr" else sp.csc_matrix
    Xt_sparse = sparse_constructor(
        (Xt, X.indices.copy(), X.indptr.copy()), shape=X.shape, dtype=bool
    )

    return Xt_sparse
|
||||
|
||||
|
||||
@validate_params(
    {
        "X": ["array-like", "sparse matrix"],
        "mask": ["array-like"],
    },
    prefer_skip_nested_validation=True,
)
def safe_mask(X, mask):
    """Return a mask which is safe to use on X.

    Parameters
    ----------
    X : {array-like, sparse matrix}
        Data on which to apply mask.

    mask : array-like
        Mask to be used on X.

    Returns
    -------
    mask : ndarray
        Array that is safe to use on X.

    Examples
    --------
    >>> from sklearn.utils import safe_mask
    >>> from scipy.sparse import csr_matrix
    >>> data = csr_matrix([[1], [2], [3], [4], [5]])
    >>> condition = [False, True, True, False, True]
    >>> mask = safe_mask(data, condition)
    >>> data[mask].toarray()
    array([[2],
           [3],
           [5]])
    """
    mask = np.asarray(mask)
    if np.issubdtype(mask.dtype, np.signedinteger):
        # Integer indices are already safe for dense and sparse containers.
        return mask
    if not hasattr(X, "toarray"):
        # Dense data accepts a boolean mask directly.
        return mask
    # Sparse matrices cannot be indexed with a boolean mask, so translate
    # the mask into the equivalent integer indices.
    return np.arange(mask.shape[0])[mask]
|
||||
|
||||
|
||||
def axis0_safe_slice(X, mask, len_mask):
    """Return a mask which is safer to use on X than safe_mask.

    This mask is safer than safe_mask since it returns an
    empty array, when a sparse matrix is sliced with a boolean mask
    with all False, instead of raising an unhelpful error in older
    versions of SciPy.

    See: https://github.com/scipy/scipy/issues/5361

    Also note that we can avoid doing the dot product by checking if
    the len_mask is not zero in _huber_loss_and_gradient but this
    is not going to be the bottleneck, since the number of outliers
    and non_outliers are typically non-zero and it makes the code
    tougher to follow.

    Parameters
    ----------
    X : {array-like, sparse matrix}
        Data on which to apply mask.

    mask : ndarray
        Mask to be used on X.

    len_mask : int
        The length of the mask.

    Returns
    -------
    mask : ndarray
        Array that is safe to use on X.
    """
    if len_mask == 0:
        # Short-circuit the all-False case with an explicit empty result
        # instead of slicing (which used to fail for sparse X in old SciPy).
        return np.zeros(shape=(0, X.shape[1]))
    return X[safe_mask(X, mask), :]
|
||||
|
||||
|
||||
def indices_to_mask(indices, mask_length):
    """Convert list of indices to boolean mask.

    Parameters
    ----------
    indices : list-like
        List of integers treated as indices.
    mask_length : int
        Length of boolean mask to be generated.
        This parameter must be greater than max(indices).

    Returns
    -------
    mask : 1d boolean nd-array
        Boolean array that is True where indices are present, else False.

    Examples
    --------
    >>> from sklearn.utils._mask import indices_to_mask
    >>> indices = [1, 2 , 3, 4]
    >>> indices_to_mask(indices, 5)
    array([False,  True,  True,  True,  True])
    """
    # Reject indices that would fall outside the requested mask.
    if np.max(indices) >= mask_length:
        raise ValueError("mask_length must be greater than max(indices)")

    mask = np.zeros(mask_length, dtype=bool)
    mask[indices] = True
    return mask
|
||||
File diff suppressed because it is too large
Load Diff
68
venv/lib/python3.12/site-packages/sklearn/utils/_missing.py
Normal file
68
venv/lib/python3.12/site-packages/sklearn/utils/_missing.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import math
|
||||
import numbers
|
||||
from contextlib import suppress
|
||||
|
||||
|
||||
def is_scalar_nan(x):
    """Test if x is NaN.

    This function is meant to overcome the issue that np.isnan does not allow
    non-numerical types as input, and that np.nan is not float('nan').

    Parameters
    ----------
    x : any type
        Any scalar value.

    Returns
    -------
    bool
        Returns true if x is NaN, and false otherwise.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils._missing import is_scalar_nan
    >>> is_scalar_nan(np.nan)
    True
    >>> is_scalar_nan(float("nan"))
    True
    >>> is_scalar_nan(None)
    False
    >>> is_scalar_nan("")
    False
    >>> is_scalar_nan([np.nan])
    False
    """
    if isinstance(x, numbers.Integral):
        # Integers (including bools and numpy integer scalars) can never
        # be NaN, and math.isnan on huge ints would raise OverflowError.
        return False
    if not isinstance(x, numbers.Real):
        # Non-numeric inputs (None, strings, containers, ...) are not NaN.
        return False
    return math.isnan(x)
|
||||
|
||||
|
||||
def is_pandas_na(x):
    """Test if x is pandas.NA.

    We intentionally do not use this function to return `True` for `pd.NA` in
    `is_scalar_nan`, because estimators that support `pd.NA` are the exception
    rather than the rule at the moment. When `pd.NA` is more universally
    supported, we may reconsider this decision.

    Parameters
    ----------
    x : any type

    Returns
    -------
    boolean
    """
    try:
        from pandas import NA
    except ImportError:
        # Without pandas installed, nothing can be pandas.NA.
        return False
    return x is NA
|
||||
419
venv/lib/python3.12/site-packages/sklearn/utils/_mocking.py
Normal file
419
venv/lib/python3.12/site-packages/sklearn/utils/_mocking.py
Normal file
@@ -0,0 +1,419 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..base import BaseEstimator, ClassifierMixin
|
||||
from ..utils._metadata_requests import RequestMethod
|
||||
from .metaestimators import available_if
|
||||
from .validation import (
|
||||
_check_sample_weight,
|
||||
_num_samples,
|
||||
check_array,
|
||||
check_is_fitted,
|
||||
check_random_state,
|
||||
)
|
||||
|
||||
|
||||
class ArraySlicingWrapper:
    """Minimal stand-in for a pandas ``.iloc`` indexer.

    Slicing the wrapper returns a :class:`MockDataFrame` wrapping the
    sliced data, mimicking how ``.iloc`` returns a dataframe.

    Parameters
    ----------
    array
        The wrapped array-like.
    """

    def __init__(self, array):
        self.array = array

    def __getitem__(self, aslice):
        # Mirror pandas .iloc: positional slicing yields a dataframe-like.
        return MockDataFrame(self.array[aslice])
|
||||
|
||||
|
||||
class MockDataFrame:
    """Minimal dataframe-like wrapper around an array for tests.

    Exposes ``shape``, ``ndim``, ``values``, ``__len__``, ``iloc`` and
    ``take`` like a pandas DataFrame, but does not support direct ``[]``
    indexing.

    Parameters
    ----------
    array
        The wrapped array-like.
    """

    # have shape and length but don't support indexing.

    def __init__(self, array):
        self.array = array
        self.values = array
        self.shape = array.shape
        self.ndim = array.ndim
        # ugly hack to make iloc work.
        self.iloc = ArraySlicingWrapper(array)

    def __len__(self):
        return len(self.array)

    def __array__(self, dtype=None):
        # Pandas data frames also are array-like: we want to make sure that
        # input validation in cross-validation does not try to call that
        # method.
        return self.array

    def __eq__(self, other):
        # Element-wise comparison wrapped in a MockDataFrame, mirroring
        # pandas semantics (not a single bool).
        return MockDataFrame(self.array == other.array)

    def __ne__(self, other):
        # NOTE(review): `self == other` is a MockDataFrame, so `not` goes
        # through its __len__ — this is True only when the comparison result
        # is empty, not element-wise negation. Presumably sufficient for the
        # tests that use it; confirm before relying on it.
        return not self == other

    def take(self, indices, axis=0):
        return MockDataFrame(self.array.take(indices, axis=axis))
|
||||
|
||||
|
||||
class CheckingClassifier(ClassifierMixin, BaseEstimator):
|
||||
"""Dummy classifier to test pipelining and meta-estimators.
|
||||
|
||||
Checks some property of `X` and `y`in fit / predict.
|
||||
This allows testing whether pipelines / cross-validation or metaestimators
|
||||
changed the input.
|
||||
|
||||
Can also be used to check if `fit_params` are passed correctly, and
|
||||
to force a certain score to be returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
check_y, check_X : callable, default=None
|
||||
The callable used to validate `X` and `y`. These callable should return
|
||||
a bool where `False` will trigger an `AssertionError`. If `None`, the
|
||||
data is not validated. Default is `None`.
|
||||
|
||||
check_y_params, check_X_params : dict, default=None
|
||||
The optional parameters to pass to `check_X` and `check_y`. If `None`,
|
||||
then no parameters are passed in.
|
||||
|
||||
methods_to_check : "all" or list of str, default="all"
|
||||
The methods in which the checks should be applied. By default,
|
||||
all checks will be done on all methods (`fit`, `predict`,
|
||||
`predict_proba`, `decision_function` and `score`).
|
||||
|
||||
foo_param : int, default=0
|
||||
A `foo` param. When `foo > 1`, the output of :meth:`score` will be 1
|
||||
otherwise it is 0.
|
||||
|
||||
expected_sample_weight : bool, default=False
|
||||
Whether to check if a valid `sample_weight` was passed to `fit`.
|
||||
|
||||
expected_fit_params : list of str, default=None
|
||||
A list of the expected parameters given when calling `fit`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
classes_ : int
|
||||
The classes seen during `fit`.
|
||||
|
||||
n_features_in_ : int
|
||||
The number of features seen during `fit`.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from sklearn.utils._mocking import CheckingClassifier
|
||||
|
||||
This helper allow to assert to specificities regarding `X` or `y`. In this
|
||||
case we expect `check_X` or `check_y` to return a boolean.
|
||||
|
||||
>>> from sklearn.datasets import load_iris
|
||||
>>> X, y = load_iris(return_X_y=True)
|
||||
>>> clf = CheckingClassifier(check_X=lambda x: x.shape == (150, 4))
|
||||
>>> clf.fit(X, y)
|
||||
CheckingClassifier(...)
|
||||
|
||||
We can also provide a check which might raise an error. In this case, we
|
||||
expect `check_X` to return `X` and `check_y` to return `y`.
|
||||
|
||||
>>> from sklearn.utils import check_array
|
||||
>>> clf = CheckingClassifier(check_X=check_array)
|
||||
>>> clf.fit(X, y)
|
||||
CheckingClassifier(...)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
check_y=None,
|
||||
check_y_params=None,
|
||||
check_X=None,
|
||||
check_X_params=None,
|
||||
methods_to_check="all",
|
||||
foo_param=0,
|
||||
expected_sample_weight=None,
|
||||
expected_fit_params=None,
|
||||
random_state=None,
|
||||
):
|
||||
self.check_y = check_y
|
||||
self.check_y_params = check_y_params
|
||||
self.check_X = check_X
|
||||
self.check_X_params = check_X_params
|
||||
self.methods_to_check = methods_to_check
|
||||
self.foo_param = foo_param
|
||||
self.expected_sample_weight = expected_sample_weight
|
||||
self.expected_fit_params = expected_fit_params
|
||||
self.random_state = random_state
|
||||
|
||||
def _check_X_y(self, X, y=None, should_be_fitted=True):
|
||||
"""Validate X and y and make extra check.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like of shape (n_samples, n_features)
|
||||
The data set.
|
||||
`X` is checked only if `check_X` is not `None` (default is None).
|
||||
y : array-like of shape (n_samples), default=None
|
||||
The corresponding target, by default `None`.
|
||||
`y` is checked only if `check_y` is not `None` (default is None).
|
||||
should_be_fitted : bool, default=True
|
||||
Whether or not the classifier should be already fitted.
|
||||
By default True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
X, y
|
||||
"""
|
||||
if should_be_fitted:
|
||||
check_is_fitted(self)
|
||||
if self.check_X is not None:
|
||||
params = {} if self.check_X_params is None else self.check_X_params
|
||||
checked_X = self.check_X(X, **params)
|
||||
if isinstance(checked_X, (bool, np.bool_)):
|
||||
assert checked_X
|
||||
else:
|
||||
X = checked_X
|
||||
if y is not None and self.check_y is not None:
|
||||
params = {} if self.check_y_params is None else self.check_y_params
|
||||
checked_y = self.check_y(y, **params)
|
||||
if isinstance(checked_y, (bool, np.bool_)):
|
||||
assert checked_y
|
||||
else:
|
||||
y = checked_y
|
||||
return X, y
|
||||
|
||||
def fit(self, X, y, sample_weight=None, **fit_params):
|
||||
"""Fit classifier.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like of shape (n_samples, n_features)
|
||||
Training vector, where `n_samples` is the number of samples and
|
||||
`n_features` is the number of features.
|
||||
|
||||
y : array-like of shape (n_samples, n_outputs) or (n_samples,), \
|
||||
default=None
|
||||
Target relative to X for classification or regression;
|
||||
None for unsupervised learning.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights. If None, then samples are equally weighted.
|
||||
|
||||
**fit_params : dict of string -> object
|
||||
Parameters passed to the ``fit`` method of the estimator
|
||||
|
||||
Returns
|
||||
-------
|
||||
self
|
||||
"""
|
||||
assert _num_samples(X) == _num_samples(y)
|
||||
if self.methods_to_check == "all" or "fit" in self.methods_to_check:
|
||||
X, y = self._check_X_y(X, y, should_be_fitted=False)
|
||||
self.n_features_in_ = np.shape(X)[1]
|
||||
self.classes_ = np.unique(check_array(y, ensure_2d=False, allow_nd=True))
|
||||
if self.expected_fit_params:
|
||||
missing = set(self.expected_fit_params) - set(fit_params)
|
||||
if missing:
|
||||
raise AssertionError(
|
||||
f"Expected fit parameter(s) {list(missing)} not seen."
|
||||
)
|
||||
for key, value in fit_params.items():
|
||||
if _num_samples(value) != _num_samples(X):
|
||||
raise AssertionError(
|
||||
f"Fit parameter {key} has length {_num_samples(value)}"
|
||||
f"; expected {_num_samples(X)}."
|
||||
)
|
||||
if self.expected_sample_weight:
|
||||
if sample_weight is None:
|
||||
raise AssertionError("Expected sample_weight to be passed")
|
||||
_check_sample_weight(sample_weight, X)
|
||||
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
"""Predict the first class seen in `classes_`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like of shape (n_samples, n_features)
|
||||
The input data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
preds : ndarray of shape (n_samples,)
|
||||
Predictions of the first class seen in `classes_`.
|
||||
"""
|
||||
if self.methods_to_check == "all" or "predict" in self.methods_to_check:
|
||||
X, y = self._check_X_y(X)
|
||||
rng = check_random_state(self.random_state)
|
||||
return rng.choice(self.classes_, size=_num_samples(X))
|
||||
|
||||
def predict_proba(self, X):
|
||||
"""Predict probabilities for each class.
|
||||
|
||||
Here, the dummy classifier will provide a probability of 1 for the
|
||||
first class of `classes_` and 0 otherwise.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like of shape (n_samples, n_features)
|
||||
The input data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
proba : ndarray of shape (n_samples, n_classes)
|
||||
The probabilities for each sample and class.
|
||||
"""
|
||||
if self.methods_to_check == "all" or "predict_proba" in self.methods_to_check:
|
||||
X, y = self._check_X_y(X)
|
||||
rng = check_random_state(self.random_state)
|
||||
proba = rng.randn(_num_samples(X), len(self.classes_))
|
||||
proba = np.abs(proba, out=proba)
|
||||
proba /= np.sum(proba, axis=1)[:, np.newaxis]
|
||||
return proba
|
||||
|
||||
def decision_function(self, X):
|
||||
"""Confidence score.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like of shape (n_samples, n_features)
|
||||
The input data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
decision : ndarray of shape (n_samples,) if n_classes == 2\
|
||||
else (n_samples, n_classes)
|
||||
Confidence score.
|
||||
"""
|
||||
if (
|
||||
self.methods_to_check == "all"
|
||||
or "decision_function" in self.methods_to_check
|
||||
):
|
||||
X, y = self._check_X_y(X)
|
||||
rng = check_random_state(self.random_state)
|
||||
if len(self.classes_) == 2:
|
||||
# for binary classifier, the confidence score is related to
|
||||
# classes_[1] and therefore should be null.
|
||||
return rng.randn(_num_samples(X))
|
||||
else:
|
||||
return rng.randn(_num_samples(X), len(self.classes_))
|
||||
|
||||
def score(self, X=None, Y=None):
    """Fake score.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data, where `n_samples` is the number of samples and
        `n_features` is the number of features.

    Y : array-like of shape (n_samples, n_output) or (n_samples,)
        Target relative to X for classification or regression;
        None for unsupervised learning.

    Returns
    -------
    score : float
        Either 0 or 1 depending of `foo_param` (i.e. `foo_param > 1 =>
        score=1` otherwise `score=0`).
    """
    # Only validate the input when this method is listed in methods_to_check.
    if self.methods_to_check == "all" or "score" in self.methods_to_check:
        self._check_X_y(X, Y)
    # The returned score depends solely on the foo_param hyper-parameter,
    # never on the data.
    if self.foo_param > 1:
        score = 1.0
    else:
        score = 0.0
    return score
|
||||
|
||||
def __sklearn_tags__(self):
    """Return estimator tags, opting this mock out of the common tests."""
    tags = super().__sklearn_tags__()
    # Mock estimator: exclude it from the generic estimator checks.
    tags._skip_test = True
    tags.input_tags.two_d_array = False
    tags.target_tags.one_d_labels = True
    return tags
|
||||
|
||||
|
||||
# Deactivate key validation for CheckingClassifier because we want to be able to
# call fit with arbitrary fit_params and record them. Without this change, we
# would get an error because those arbitrary params are not expected.
# NOTE(review): RequestMethod appears to come from the metadata-routing
# machinery — confirm against its definition elsewhere in the package.
CheckingClassifier.set_fit_request = RequestMethod(  # type: ignore[assignment,method-assign]
    name="fit", keys=[], validate_keys=False
)
|
||||
|
||||
|
||||
class NoSampleWeightWrapper(BaseEstimator):
    """Wrap estimator which will not expose `sample_weight`.

    All calls are delegated verbatim to the wrapped estimator; since `fit`
    takes no `sample_weight` argument, the wrapper hides that capability.

    Parameters
    ----------
    est : estimator, default=None
        The estimator to wrap.
    """

    def __init__(self, est=None):
        self.est = est

    def fit(self, X, y):
        # Delegate without a sample_weight parameter in the signature.
        return self.est.fit(X, y)

    def predict(self, X):
        return self.est.predict(X)

    def predict_proba(self, X):
        return self.est.predict_proba(X)

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        # Mock estimator: exclude it from the generic estimator checks.
        tags._skip_test = True
        return tags
|
||||
|
||||
|
||||
def _check_response(method):
|
||||
def check(self):
|
||||
return self.response_methods is not None and method in self.response_methods
|
||||
|
||||
return check
|
||||
|
||||
|
||||
class _MockEstimatorOnOffPrediction(BaseEstimator):
    """Estimator for which we can turn on/off the prediction methods.

    Parameters
    ----------
    response_methods: list of \
            {"predict", "predict_proba", "decision_function"}, default=None
        List containing the response implemented by the estimator. When, the
        response is in the list, it will return the name of the response method
        when called. Otherwise, an `AttributeError` is raised. It allows to
        use `getattr` as any conventional estimator. By default, no response
        methods are mocked.
    """

    def __init__(self, response_methods=None):
        self.response_methods = response_methods

    def fit(self, X, y):
        # Minimal fit: only record the classes seen in `y`.
        self.classes_ = np.unique(y)
        return self

    # Each response method is gated by `available_if` so that `hasattr` /
    # `getattr` behave exactly as on a real estimator that lacks the method.
    @available_if(_check_response("predict"))
    def predict(self, X):
        return "predict"

    @available_if(_check_response("predict_proba"))
    def predict_proba(self, X):
        return "predict_proba"

    @available_if(_check_response("decision_function"))
    def decision_function(self, X):
        return "decision_function"
|
||||
Binary file not shown.
@@ -0,0 +1,33 @@
|
||||
# Helpers to safely access OpenMP routines
#
# no-op implementations are provided for the case where OpenMP is not available.
#
# All calls to OpenMP routines should be cimported from this module.

# The verbatim C block below is compiled as-is: when OpenMP is available the
# real <omp.h> symbols are used; otherwise each routine is #define'd to a
# no-op so callers never need an #ifdef of their own.
cdef extern from *:
    """
    #ifdef _OPENMP
        #include <omp.h>
        #define SKLEARN_OPENMP_PARALLELISM_ENABLED 1
    #else
        #define SKLEARN_OPENMP_PARALLELISM_ENABLED 0
        #define omp_lock_t int
        #define omp_init_lock(l) (void)0
        #define omp_destroy_lock(l) (void)0
        #define omp_set_lock(l) (void)0
        #define omp_unset_lock(l) (void)0
        #define omp_get_thread_num() 0
        #define omp_get_max_threads() 1
    #endif
    """
    # Compile-time flag exposing whether OpenMP was enabled in this build.
    bint SKLEARN_OPENMP_PARALLELISM_ENABLED

    # Opaque lock type; plain int in the no-OpenMP fallback.
    ctypedef struct omp_lock_t:
        pass

    void omp_init_lock(omp_lock_t*) noexcept nogil
    void omp_destroy_lock(omp_lock_t*) noexcept nogil
    void omp_set_lock(omp_lock_t*) noexcept nogil
    void omp_unset_lock(omp_lock_t*) noexcept nogil
    int omp_get_thread_num() noexcept nogil
    int omp_get_max_threads() noexcept nogil
|
||||
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
from joblib import cpu_count
|
||||
|
||||
|
||||
# Module level cache for cpu_count as we do not expect this to change during
# the lifecycle of a Python program. This dictionary is keyed by
# only_physical_cores.
_CPU_COUNTS = {}
|
||||
|
||||
|
||||
def _openmp_parallelism_enabled():
    """Determines whether scikit-learn has been built with OpenMP

    It allows to retrieve at runtime the information gathered at compile time.
    """
    # SKLEARN_OPENMP_PARALLELISM_ENABLED is resolved at compile time and defined
    # in _openmp_helpers.pxd as a boolean. This function exposes it to Python.
    return SKLEARN_OPENMP_PARALLELISM_ENABLED
|
||||
|
||||
|
||||
cpdef _openmp_effective_n_threads(n_threads=None, only_physical_cores=True):
    """Determine the effective number of threads to be used for OpenMP calls

    - For ``n_threads = None``,
      - if the ``OMP_NUM_THREADS`` environment variable is set, return
        ``openmp.omp_get_max_threads()``
      - otherwise, return the minimum between ``openmp.omp_get_max_threads()``
        and the number of cpus, taking cgroups quotas into account. Cgroups
        quotas can typically be set by tools such as Docker.
      The result of ``omp_get_max_threads`` can be influenced by environment
      variable ``OMP_NUM_THREADS`` or at runtime by ``omp_set_num_threads``.

    - For ``n_threads > 0``, return this as the maximal number of threads for
      parallel OpenMP calls.

    - For ``n_threads < 0``, return the maximal number of threads minus
      ``|n_threads + 1|``. In particular ``n_threads = -1`` will use as many
      threads as there are available cores on the machine.

    - Raise a ValueError for ``n_threads = 0``.

    Passing the `only_physical_cores=False` flag makes it possible to use extra
    threads for SMT/HyperThreading logical cores. It has been empirically
    observed that using as many threads as available SMT cores can slightly
    improve the performance in some cases, but can severely degrade
    performance other times. Therefore it is recommended to use
    `only_physical_cores=True` unless an empirical study has been conducted to
    assess the impact of SMT on a case-by-case basis (using various input data
    shapes, in particular small data shapes).

    If scikit-learn is built without OpenMP support, always return 1.
    """
    if n_threads == 0:
        raise ValueError("n_threads = 0 is invalid")

    if not SKLEARN_OPENMP_PARALLELISM_ENABLED:
        # OpenMP disabled at build-time => sequential mode
        return 1

    if os.getenv("OMP_NUM_THREADS"):
        # Fall back to user provided number of threads making it possible
        # to exceed the number of cpus.
        max_n_threads = omp_get_max_threads()
    else:
        # Cache the (potentially expensive) cpu_count result per
        # only_physical_cores value for the lifetime of the process.
        try:
            n_cpus = _CPU_COUNTS[only_physical_cores]
        except KeyError:
            n_cpus = cpu_count(only_physical_cores=only_physical_cores)
            _CPU_COUNTS[only_physical_cores] = n_cpus
        max_n_threads = min(omp_get_max_threads(), n_cpus)

    if n_threads is None:
        return max_n_threads
    elif n_threads < 0:
        # Negative n_threads counts back from the maximum (-1 == all cores),
        # clamped to at least one thread.
        return max(1, max_n_threads + n_threads + 1)

    return n_threads
|
||||
@@ -0,0 +1,46 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
|
||||
def check_matplotlib_support(caller_name):
    """Raise ImportError with detailed error message if mpl is not installed.

    Plot utilities like any of the Display's plotting functions should lazily import
    matplotlib and call this helper before any computation.

    Parameters
    ----------
    caller_name : str
        The name of the caller that requires matplotlib.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    try:
        import matplotlib  # noqa: F401
    except ImportError as e:
        # f-string instead of str.format for consistency with the rest of the
        # codebase; the emitted message is unchanged.
        raise ImportError(
            f"{caller_name} requires matplotlib. You can install matplotlib with "
            "`pip install matplotlib`"
        ) from e
|
||||
|
||||
|
||||
def check_pandas_support(caller_name):
    """Raise ImportError with detailed error message if pandas is not installed.

    Plot utilities like :func:`fetch_openml` should lazily import
    pandas and call this helper before any computation.

    Parameters
    ----------
    caller_name : str
        The name of the caller that requires pandas.

    Returns
    -------
    pandas
        The pandas package.

    Raises
    ------
    ImportError
        If pandas is not installed.
    """
    try:
        import pandas

        return pandas
    except ImportError as e:
        # f-string instead of str.format for consistency with the rest of the
        # codebase; the emitted message is unchanged.
        raise ImportError(f"{caller_name} requires pandas.") from e
|
||||
@@ -0,0 +1,910 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import functools
|
||||
import math
|
||||
import operator
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from inspect import signature
|
||||
from numbers import Integral, Real
|
||||
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix, issparse
|
||||
|
||||
from .._config import config_context, get_config
|
||||
from .validation import _is_arraylike_not_scalar
|
||||
|
||||
|
||||
class InvalidParameterError(ValueError, TypeError):
    """Exception raised when a class/method/function parameter has an invalid
    type or value.

    Deriving from both ValueError and TypeError keeps backward compatibility
    with callers that catch either of those exceptions.
    """
|
||||
|
||||
def validate_parameter_constraints(parameter_constraints, params, caller_name):
    """Validate types and values of given parameters.

    Parameters
    ----------
    parameter_constraints : dict or {"no_validation"}
        If "no_validation", validation is skipped for this parameter.

        If a dict, it must be a dictionary `param_name: list of constraints`.
        A parameter is valid if it satisfies one of the constraints from the list.
        Constraints can be:
        - an Interval object, representing a continuous or discrete range of numbers
        - the string "array-like"
        - the string "sparse matrix"
        - the string "random_state"
        - callable
        - None, meaning that None is a valid value for the parameter
        - any type, meaning that any instance of this type is valid
        - an Options object, representing a set of elements of a given type
        - a StrOptions object, representing a set of strings
        - the string "boolean"
        - the string "verbose"
        - the string "cv_object"
        - the string "nan"
        - a MissingValues object representing markers for missing values
        - a HasMethods object, representing method(s) an object must have
        - a Hidden object, representing a constraint not meant to be exposed to the user

    params : dict
        A dictionary `param_name: param_value`. The parameters to validate against the
        constraints.

    caller_name : str
        The name of the estimator or function or method that called this function.

    Raises
    ------
    InvalidParameterError
        If a parameter satisfies none of its constraints.
    """
    for param_name, param_val in params.items():
        # We allow parameters to not have a constraint so that third party estimators
        # can inherit from sklearn estimators without having to necessarily use the
        # validation tools.
        if param_name not in parameter_constraints:
            continue

        constraints = parameter_constraints[param_name]

        if constraints == "no_validation":
            continue

        # Normalize the declarative constraints into _Constraint instances.
        constraints = [make_constraint(constraint) for constraint in constraints]

        for constraint in constraints:
            if constraint.is_satisfied_by(param_val):
                # this constraint is satisfied, no need to check further.
                break
        else:
            # No constraint is satisfied, raise with an informative message.

            # Ignore constraints that we don't want to expose in the error message,
            # i.e. options that are for internal purpose or not officially supported.
            constraints = [
                constraint for constraint in constraints if not constraint.hidden
            ]

            if len(constraints) == 1:
                constraints_str = f"{constraints[0]}"
            else:
                constraints_str = (
                    f"{', '.join([str(c) for c in constraints[:-1]])} or"
                    f" {constraints[-1]}"
                )

            raise InvalidParameterError(
                f"The {param_name!r} parameter of {caller_name} must be"
                f" {constraints_str}. Got {param_val!r} instead."
            )
|
||||
|
||||
|
||||
def make_constraint(constraint):
    """Convert the constraint into the appropriate Constraint object.

    Parameters
    ----------
    constraint : object
        The constraint to convert.

    Returns
    -------
    constraint : instance of _Constraint
        The converted constraint.

    Raises
    ------
    ValueError
        If the constraint does not match any known declarative form.
    """
    if isinstance(constraint, str) and constraint == "array-like":
        return _ArrayLikes()
    if isinstance(constraint, str) and constraint == "sparse matrix":
        return _SparseMatrices()
    if isinstance(constraint, str) and constraint == "random_state":
        return _RandomStates()
    # Identity check on purpose: the declarative form is the builtin
    # `callable` itself, not "any callable object".
    if constraint is callable:
        return _Callables()
    if constraint is None:
        return _NoneConstraint()
    if isinstance(constraint, type):
        return _InstancesOf(constraint)
    # Already-constructed constraint objects pass through unchanged.
    if isinstance(
        constraint, (Interval, StrOptions, Options, HasMethods, MissingValues)
    ):
        return constraint
    if isinstance(constraint, str) and constraint == "boolean":
        return _Booleans()
    if isinstance(constraint, str) and constraint == "verbose":
        return _VerboseHelper()
    if isinstance(constraint, str) and constraint == "cv_object":
        return _CVObjects()
    # Hidden wraps another constraint; convert the inner one and flag it so it
    # is omitted from user-facing error messages.
    if isinstance(constraint, Hidden):
        constraint = make_constraint(constraint.constraint)
        constraint.hidden = True
        return constraint
    if (isinstance(constraint, str) and constraint == "nan") or (
        isinstance(constraint, float) and np.isnan(constraint)
    ):
        return _NanConstraint()
    raise ValueError(f"Unknown constraint type: {constraint}")
|
||||
|
||||
|
||||
def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
    """Decorator to validate types and values of functions and methods.

    Parameters
    ----------
    parameter_constraints : dict
        A dictionary `param_name: list of constraints`. See the docstring of
        `validate_parameter_constraints` for a description of the accepted constraints.

        Note that the *args and **kwargs parameters are not validated and must not be
        present in the parameter_constraints dictionary.

    prefer_skip_nested_validation : bool
        If True, the validation of parameters of inner estimators or functions
        called by the decorated function will be skipped.

        This is useful to avoid validating many times the parameters passed by the
        user from the public facing API. It's also useful to avoid validating
        parameters that we pass internally to inner functions that are guaranteed to
        be valid by the test suite.

        It should be set to True for most functions, except for those that receive
        non-validated objects as parameters or that are just wrappers around classes
        because they only perform a partial validation.

    Returns
    -------
    decorated_function : function or method
        The decorated function.
    """

    def decorator(func):
        # The dict of parameter constraints is set as an attribute of the function
        # to make it possible to dynamically introspect the constraints for
        # automatic testing.
        setattr(func, "_skl_parameter_constraints", parameter_constraints)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Global opt-out: skip all validation when the config flag is set.
            global_skip_validation = get_config()["skip_parameter_validation"]
            if global_skip_validation:
                return func(*args, **kwargs)

            func_sig = signature(func)

            # Map *args/**kwargs to the function signature
            params = func_sig.bind(*args, **kwargs)
            params.apply_defaults()

            # ignore self/cls and positional/keyword markers
            to_ignore = [
                p.name
                for p in func_sig.parameters.values()
                if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
            ]
            to_ignore += ["self", "cls"]
            params = {k: v for k, v in params.arguments.items() if k not in to_ignore}

            validate_parameter_constraints(
                parameter_constraints, params, caller_name=func.__qualname__
            )

            try:
                # Propagate the skip preference to nested calls via the config
                # context so inner estimators don't re-validate.
                with config_context(
                    skip_parameter_validation=(
                        prefer_skip_nested_validation or global_skip_validation
                    )
                ):
                    return func(*args, **kwargs)
            except InvalidParameterError as e:
                # When the function is just a wrapper around an estimator, we allow
                # the function to delegate validation to the estimator, but we replace
                # the name of the estimator by the name of the function in the error
                # message to avoid confusion.
                msg = re.sub(
                    r"parameter of \w+ must be",
                    f"parameter of {func.__qualname__} must be",
                    str(e),
                )
                raise InvalidParameterError(msg) from e

        return wrapper

    return decorator
|
||||
|
||||
|
||||
class RealNotInt(Real):
    """A type that represents reals that are not instances of int.

    Behaves like float, but also works with values extracted from numpy arrays.
    isinstance(1, RealNotInt) -> False
    isinstance(1.0, RealNotInt) -> True
    """


# Register float as a virtual subclass; numpy floating scalars are recognized
# through the Real ABC machinery as well.
RealNotInt.register(float)
|
||||
|
||||
|
||||
def _type_name(t):
|
||||
"""Convert type into human readable string."""
|
||||
module = t.__module__
|
||||
qualname = t.__qualname__
|
||||
if module == "builtins":
|
||||
return qualname
|
||||
elif t == Real:
|
||||
return "float"
|
||||
elif t == Integral:
|
||||
return "int"
|
||||
return f"{module}.{qualname}"
|
||||
|
||||
|
||||
class _Constraint(ABC):
    """Base class for the constraint objects."""

    def __init__(self):
        # When True, the constraint is omitted from user-facing error messages
        # (see validate_parameter_constraints).
        self.hidden = False

    @abstractmethod
    def is_satisfied_by(self, val):
        """Whether or not a value satisfies the constraint.

        Parameters
        ----------
        val : object
            The value to check.

        Returns
        -------
        is_satisfied : bool
            Whether or not the constraint is satisfied by this value.
        """

    @abstractmethod
    def __str__(self):
        """A human readable representational string of the constraint."""
|
||||
|
||||
|
||||
class _InstancesOf(_Constraint):
    """Constraint representing instances of a given type.

    Parameters
    ----------
    type : type
        The valid type.
    """

    def __init__(self, type):
        super().__init__()
        self.type = type

    def is_satisfied_by(self, val):
        # Plain isinstance check; subclasses of self.type are accepted.
        return isinstance(val, self.type)

    def __str__(self):
        return f"an instance of {_type_name(self.type)!r}"
|
||||
|
||||
|
||||
class _NoneConstraint(_Constraint):
    """Constraint representing the None singleton."""

    def is_satisfied_by(self, val):
        # Identity check: only the None singleton satisfies this constraint.
        return val is None

    def __str__(self):
        return "None"
|
||||
|
||||
|
||||
class _NanConstraint(_Constraint):
    """Constraint representing the indicator `np.nan`."""

    def is_satisfied_by(self, val):
        # Integral values can never be NaN; ruling them out first also keeps
        # math.isnan from being called on non-float integral types.
        if isinstance(val, Integral):
            return False
        return isinstance(val, Real) and math.isnan(val)

    def __str__(self):
        return "numpy.nan"
|
||||
|
||||
|
||||
class _PandasNAConstraint(_Constraint):
    """Constraint representing the indicator `pd.NA`."""

    def is_satisfied_by(self, val):
        # pandas is an optional dependency: without it, nothing can be pd.NA.
        try:
            import pandas as pd
        except ImportError:
            return False
        return isinstance(val, type(pd.NA)) and pd.isna(val)

    def __str__(self):
        return "pandas.NA"
|
||||
|
||||
|
||||
class Options(_Constraint):
    """Constraint representing a finite set of instances of a given type.

    Parameters
    ----------
    type : type
        The type of the valid options.

    options : set
        The set of valid scalars.

    deprecated : set or None, default=None
        A subset of the `options` to mark as deprecated in the string
        representation of the constraint.
    """

    def __init__(self, type, options, *, deprecated=None):
        super().__init__()
        self.type = type
        self.options = options
        self.deprecated = deprecated or set()

        # A deprecated option must still be a valid option.
        if self.deprecated - self.options:
            raise ValueError("The deprecated options must be a subset of the options.")

    def is_satisfied_by(self, val):
        # Both the type and the membership must match.
        return isinstance(val, self.type) and val in self.options

    def _mark_if_deprecated(self, option):
        """Add a deprecated mark to an option if needed."""
        option_str = f"{option!r}"
        if option in self.deprecated:
            option_str = f"{option_str} (deprecated)"
        return option_str

    def __str__(self):
        options_str = (
            f"{', '.join([self._mark_if_deprecated(o) for o in self.options])}"
        )
        return f"a {_type_name(self.type)} among {{{options_str}}}"
|
||||
|
||||
|
||||
class StrOptions(Options):
    """Constraint representing a finite set of strings.

    Convenience subclass of :class:`Options` with `type` fixed to `str`.

    Parameters
    ----------
    options : set of str
        The set of valid strings.

    deprecated : set of str or None, default=None
        A subset of the `options` to mark as deprecated in the string
        representation of the constraint.
    """

    def __init__(self, options, *, deprecated=None):
        super().__init__(type=str, options=options, deprecated=deprecated)
|
||||
|
||||
|
||||
class Interval(_Constraint):
    """Constraint representing a typed interval.

    Parameters
    ----------
    type : {numbers.Integral, numbers.Real, RealNotInt}
        The set of numbers in which to set the interval.

        If RealNotInt, only reals that don't have the integer type
        are allowed. For example 1.0 is allowed but 1 is not.

    left : float or int or None
        The left bound of the interval. None means left bound is -∞.

    right : float, int or None
        The right bound of the interval. None means right bound is +∞.

    closed : {"left", "right", "both", "neither"}
        Whether the interval is open or closed. Possible choices are:

        - `"left"`: the interval is closed on the left and open on the right.
          It is equivalent to the interval `[ left, right )`.
        - `"right"`: the interval is closed on the right and open on the left.
          It is equivalent to the interval `( left, right ]`.
        - `"both"`: the interval is closed.
          It is equivalent to the interval `[ left, right ]`.
        - `"neither"`: the interval is open.
          It is equivalent to the interval `( left, right )`.

    Notes
    -----
    Setting a bound to `None` and setting the interval closed is valid. For instance,
    strictly speaking, `Interval(Real, 0, None, closed="both")` corresponds to
    `[0, +∞) U {+∞}`.
    """

    def __init__(self, type, left, right, *, closed):
        super().__init__()
        self.type = type
        self.left = left
        self.right = right
        self.closed = closed

        self._check_params()

    def _check_params(self):
        # Validate the interval's own definition (type, closure, bound types,
        # bound ordering) eagerly so misuse fails at declaration time.
        if self.type not in (Integral, Real, RealNotInt):
            raise ValueError(
                "type must be either numbers.Integral, numbers.Real or RealNotInt."
                f" Got {self.type} instead."
            )

        if self.closed not in ("left", "right", "both", "neither"):
            raise ValueError(
                "closed must be either 'left', 'right', 'both' or 'neither'. "
                f"Got {self.closed} instead."
            )

        if self.type is Integral:
            suffix = "for an interval over the integers."
            if self.left is not None and not isinstance(self.left, Integral):
                raise TypeError(f"Expecting left to be an int {suffix}")
            if self.right is not None and not isinstance(self.right, Integral):
                raise TypeError(f"Expecting right to be an int {suffix}")
            # A closed infinite bound is meaningless over the integers.
            if self.left is None and self.closed in ("left", "both"):
                raise ValueError(
                    f"left can't be None when closed == {self.closed} {suffix}"
                )
            if self.right is None and self.closed in ("right", "both"):
                raise ValueError(
                    f"right can't be None when closed == {self.closed} {suffix}"
                )
        else:
            if self.left is not None and not isinstance(self.left, Real):
                raise TypeError("Expecting left to be a real number.")
            if self.right is not None and not isinstance(self.right, Real):
                raise TypeError("Expecting right to be a real number.")

        if self.right is not None and self.left is not None and self.right <= self.left:
            raise ValueError(
                f"right can't be less than left. Got left={self.left} and "
                f"right={self.right}"
            )

    def __contains__(self, val):
        # NaN is never inside an interval.
        if not isinstance(val, Integral) and np.isnan(val):
            return False

        # The comparison operators are chosen to detect a *violation* of the
        # corresponding bound given the closure.
        left_cmp = operator.lt if self.closed in ("left", "both") else operator.le
        right_cmp = operator.gt if self.closed in ("right", "both") else operator.ge

        left = -np.inf if self.left is None else self.left
        right = np.inf if self.right is None else self.right

        if left_cmp(val, left):
            return False
        if right_cmp(val, right):
            return False
        return True

    def is_satisfied_by(self, val):
        # Type check first, then range membership.
        if not isinstance(val, self.type):
            return False

        return val in self

    def __str__(self):
        type_str = "an int" if self.type is Integral else "a float"
        left_bracket = "[" if self.closed in ("left", "both") else "("
        left_bound = "-inf" if self.left is None else self.left
        right_bound = "inf" if self.right is None else self.right
        right_bracket = "]" if self.closed in ("right", "both") else ")"

        # better repr if the bounds were given as integers
        if not self.type == Integral and isinstance(self.left, Real):
            left_bound = float(left_bound)
        if not self.type == Integral and isinstance(self.right, Real):
            right_bound = float(right_bound)

        return (
            f"{type_str} in the range "
            f"{left_bracket}{left_bound}, {right_bound}{right_bracket}"
        )
|
||||
|
||||
|
||||
class _ArrayLikes(_Constraint):
    """Constraint representing array-likes."""

    def is_satisfied_by(self, val):
        # Delegates to the shared validation helper; scalars are rejected.
        return _is_arraylike_not_scalar(val)

    def __str__(self):
        return "an array-like"
|
||||
|
||||
|
||||
class _SparseMatrices(_Constraint):
    """Constraint representing sparse matrices."""

    def is_satisfied_by(self, val):
        # scipy.sparse.issparse covers all scipy sparse containers.
        return issparse(val)

    def __str__(self):
        return "a sparse matrix"
|
||||
|
||||
|
||||
class _Callables(_Constraint):
    """Constraint representing callables."""

    def is_satisfied_by(self, val):
        return callable(val)

    def __str__(self):
        return "a callable"
|
||||
|
||||
|
||||
class _RandomStates(_Constraint):
    """Constraint representing random states.

    Convenience class for
    [Interval(Integral, 0, 2**32 - 1, closed="both"), np.random.RandomState, None]
    """

    def __init__(self):
        super().__init__()
        # Union of the three accepted random_state forms: a seed in the valid
        # 32-bit range, a RandomState instance, or None.
        self._constraints = [
            Interval(Integral, 0, 2**32 - 1, closed="both"),
            _InstancesOf(np.random.RandomState),
            _NoneConstraint(),
        ]

    def is_satisfied_by(self, val):
        return any(c.is_satisfied_by(val) for c in self._constraints)

    def __str__(self):
        return (
            f"{', '.join([str(c) for c in self._constraints[:-1]])} or"
            f" {self._constraints[-1]}"
        )
|
||||
|
||||
|
||||
class _Booleans(_Constraint):
    """Constraint representing boolean likes.

    Convenience class for
    [bool, np.bool_]
    """

    def __init__(self):
        super().__init__()
        # Accept both Python bool and the numpy boolean scalar type.
        self._constraints = [
            _InstancesOf(bool),
            _InstancesOf(np.bool_),
        ]

    def is_satisfied_by(self, val):
        return any(c.is_satisfied_by(val) for c in self._constraints)

    def __str__(self):
        return (
            f"{', '.join([str(c) for c in self._constraints[:-1]])} or"
            f" {self._constraints[-1]}"
        )
|
||||
|
||||
|
||||
class _VerboseHelper(_Constraint):
    """Helper constraint for the verbose parameter.

    Convenience class for
    [Interval(Integral, 0, None, closed="left"), bool, numpy.bool_]
    """

    def __init__(self):
        super().__init__()
        # verbose accepts any non-negative integer or a boolean.
        self._constraints = [
            Interval(Integral, 0, None, closed="left"),
            _InstancesOf(bool),
            _InstancesOf(np.bool_),
        ]

    def is_satisfied_by(self, val):
        return any(c.is_satisfied_by(val) for c in self._constraints)

    def __str__(self):
        return (
            f"{', '.join([str(c) for c in self._constraints[:-1]])} or"
            f" {self._constraints[-1]}"
        )
|
||||
|
||||
|
||||
class MissingValues(_Constraint):
    """Helper constraint for the `missing_values` parameters.

    Convenience for
    [
        Integral,
        Interval(Real, None, None, closed="both"),
        str,   # when numeric_only is False
        None,  # when numeric_only is False
        _NanConstraint(),
        _PandasNAConstraint(),
    ]

    Parameters
    ----------
    numeric_only : bool, default=False
        Whether to consider only numeric missing value markers.
    """

    def __init__(self, numeric_only=False):
        super().__init__()

        self.numeric_only = numeric_only

        self._constraints = [
            _InstancesOf(Integral),
            # we use an interval of Real to ignore np.nan that has its own constraint
            Interval(Real, None, None, closed="both"),
            _NanConstraint(),
            _PandasNAConstraint(),
        ]
        # String and None markers are only accepted in non-numeric mode.
        if not self.numeric_only:
            self._constraints.extend([_InstancesOf(str), _NoneConstraint()])

    def is_satisfied_by(self, val):
        return any(c.is_satisfied_by(val) for c in self._constraints)

    def __str__(self):
        return (
            f"{', '.join([str(c) for c in self._constraints[:-1]])} or"
            f" {self._constraints[-1]}"
        )
|
||||
|
||||
|
||||
class HasMethods(_Constraint):
    """Constraint representing objects that expose specific methods.

    Useful for parameters that follow a protocol, where we don't want to tie
    the accepted values to a specific module or class.

    Parameters
    ----------
    methods : str or list of str
        The method(s) that the object is expected to expose.
    """

    @validate_params(
        {"methods": [str, list]},
        prefer_skip_nested_validation=True,
    )
    def __init__(self, methods):
        super().__init__()
        # Normalize a single method name to a one-element list.
        self.methods = [methods] if isinstance(methods, str) else methods

    def is_satisfied_by(self, val):
        """Return True if `val` exposes every required method as a callable."""
        return all(callable(getattr(val, m, None)) for m in self.methods)

    def __str__(self):
        """Human-readable description, e.g. "an object implementing 'a' and 'b'"."""
        if len(self.methods) == 1:
            return f"an object implementing {self.methods[0]!r}"
        *head, tail = self.methods
        joined = f"{', '.join(repr(m) for m in head)} and {tail!r}"
        return f"an object implementing {joined}"
|
||||
|
||||
|
||||
class _IterablesNotString(_Constraint):
    """Constraint representing iterables that are not strings."""

    def is_satisfied_by(self, val):
        # Strings are iterable but are explicitly rejected here.
        if isinstance(val, str):
            return False
        return isinstance(val, Iterable)

    def __str__(self):
        return "an iterable"
|
||||
|
||||
|
||||
class _CVObjects(_Constraint):
    """Constraint representing cv objects.

    Shorthand for the union of constraints
    [
        Interval(Integral, 2, None, closed="left"),
        HasMethods(["split", "get_n_splits"]),
        _IterablesNotString(),
        None,
    ]
    """

    def __init__(self):
        super().__init__()
        # A cv may be a number of folds, a splitter object, an iterable of
        # splits, or None (meaning the default cv).
        self._constraints = [
            Interval(Integral, 2, None, closed="left"),
            HasMethods(["split", "get_n_splits"]),
            _IterablesNotString(),
            _NoneConstraint(),
        ]

    def is_satisfied_by(self, val):
        """Return True if `val` satisfies at least one sub-constraint."""
        for sub_constraint in self._constraints:
            if sub_constraint.is_satisfied_by(val):
                return True
        return False

    def __str__(self):
        """Human-readable description, e.g. "a, b or c"."""
        *head, tail = self._constraints
        return f"{', '.join(str(c) for c in head)} or {tail}"
|
||||
|
||||
|
||||
class Hidden:
    """Class encapsulating a constraint not meant to be exposed to the user.

    Parameters
    ----------
    constraint : str or _Constraint instance
        The constraint to be used internally.
    """

    def __init__(self, constraint):
        # Simply stored; the validation machinery unwraps it when needed.
        self.constraint = constraint
|
||||
|
||||
|
||||
def generate_invalid_param_val(constraint):
    """Return a value that does not satisfy the constraint.

    Raises a NotImplementedError if there exists no invalid value for this constraint.

    This is only useful for testing purpose.

    Parameters
    ----------
    constraint : _Constraint instance
        The constraint to generate a value for.

    Returns
    -------
    val : object
        A value that does not satisfy the constraint.
    """
    if isinstance(constraint, StrOptions):
        return f"not {' or '.join(constraint.options)}"

    if isinstance(constraint, MissingValues):
        return np.array([1, 2, 3])

    if isinstance(constraint, _VerboseHelper):
        # verbose must be non-negative
        return -1

    if isinstance(constraint, HasMethods):
        # An empty anonymous type exposes none of the required methods.
        return type("HasNotMethods", (), {})()

    if isinstance(constraint, _IterablesNotString):
        return "a string"

    if isinstance(constraint, _CVObjects):
        return "not a cv object"

    if isinstance(constraint, Interval):
        if constraint.type is Integral:
            # Step just outside whichever bound is finite.
            if constraint.left is not None:
                return constraint.left - 1
            if constraint.right is not None:
                return constraint.right + 1
            # There's no integer outside (-inf, +inf)
            raise NotImplementedError

        if constraint.type in (Real, RealNotInt):
            # Step just outside whichever bound is finite.
            if constraint.left is not None:
                return constraint.left - 1e-6
            if constraint.right is not None:
                return constraint.right + 1e-6

            # bounds are -inf, +inf: an open end excludes the infinity itself
            if constraint.closed in ("right", "neither"):
                return -np.inf
            if constraint.closed in ("left", "neither"):
                return np.inf

            # interval is [-inf, +inf]: only nan falls outside it
            return np.nan

    raise NotImplementedError
|
||||
|
||||
|
||||
def generate_valid_param(constraint):
    """Return a value that does satisfy a constraint.

    This is only useful for testing purpose.

    Parameters
    ----------
    constraint : Constraint instance
        The constraint to generate a value for.

    Returns
    -------
    val : object
        A value that does satisfy the constraint.
    """
    if isinstance(constraint, _ArrayLikes):
        return np.array([1, 2, 3])

    if isinstance(constraint, _SparseMatrices):
        return csr_matrix([[0, 1], [1, 0]])

    if isinstance(constraint, _RandomStates):
        return np.random.RandomState(42)

    if isinstance(constraint, _Callables):
        return lambda x: x

    if isinstance(constraint, _NoneConstraint):
        return None

    if isinstance(constraint, _InstancesOf):
        if constraint.type is np.ndarray:
            # ndarray can't be instantiated without arguments
            return np.array([1, 2, 3])
        if constraint.type in (Integral, Real):
            # Integral and Real are abstract, so return a concrete number
            return 1
        return constraint.type()

    if isinstance(constraint, _Booleans):
        return True

    if isinstance(constraint, _VerboseHelper):
        return 1

    if isinstance(constraint, MissingValues):
        return np.nan if constraint.numeric_only else "missing"

    if isinstance(constraint, HasMethods):
        # Build an anonymous type exposing exactly the required methods.
        stubs = {m: (lambda self: None) for m in constraint.methods}
        return type("ValidHasMethods", (), stubs)()

    if isinstance(constraint, _IterablesNotString):
        return [1, 2, 3]

    if isinstance(constraint, _CVObjects):
        return 5

    if isinstance(constraint, Options):  # includes StrOptions
        # Take an arbitrary option; falls through when there are none.
        for option in constraint.options:
            return option

    if isinstance(constraint, Interval):
        left, right = constraint.left, constraint.right
        if left is None and right is None:
            return 0
        if left is None:
            return right - 1
        if right is None:
            return left + 1
        # Both bounds finite: midpoint for reals, left + 1 for integers.
        if constraint.type is Real:
            return (left + right) / 2
        return left + 1

    raise ValueError(f"Unknown constraint type: {constraint}")
|
||||
419
venv/lib/python3.12/site-packages/sklearn/utils/_plotting.py
Normal file
419
venv/lib/python3.12/site-packages/sklearn/utils/_plotting.py
Normal file
@@ -0,0 +1,419 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import check_consistent_length
|
||||
from ._optional_dependencies import check_matplotlib_support
|
||||
from ._response import _get_response_values_binary
|
||||
from .fixes import parse_version
|
||||
from .multiclass import type_of_target
|
||||
from .validation import _check_pos_label_consistency, _num_samples
|
||||
|
||||
|
||||
class _BinaryClassifierCurveDisplayMixin:
    """Mixin class to be used in Displays requiring a binary classifier.

    The aim of this class is to centralize some validations regarding the estimator and
    the target and gather the response of the estimator.
    """

    def _validate_plot_params(self, *, ax=None, name=None):
        """Validate plot parameters and return `(ax, figure, name)`.

        Creates a new matplotlib axes when `ax` is None. When `name` is None,
        it falls back to `self.estimator_name` then `self.name`.
        """
        check_matplotlib_support(f"{self.__class__.__name__}.plot")
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        # Display classes are in process of changing from `estimator_name` to `name`.
        # Try old attr name: `estimator_name` first.
        if name is None:
            name = getattr(self, "estimator_name", getattr(self, "name", None))
        return ax, ax.figure, name

    @classmethod
    def _validate_and_get_response_values(
        cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None
    ):
        """Compute the estimator's binary response values for `X`.

        Returns `(y_pred, pos_label, name)`, defaulting `name` to the
        estimator's class name.
        """
        check_matplotlib_support(f"{cls.__name__}.from_estimator")

        name = estimator.__class__.__name__ if name is None else name

        y_pred, pos_label = _get_response_values_binary(
            estimator,
            X,
            response_method=response_method,
            pos_label=pos_label,
        )

        return y_pred, pos_label, name

    @classmethod
    def _validate_from_predictions_params(
        cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None
    ):
        """Validate inputs of a `from_predictions` call.

        Raises ValueError when `y_true` is not binary; returns the resolved
        `(pos_label, name)` with `name` defaulting to "Classifier".
        """
        check_matplotlib_support(f"{cls.__name__}.from_predictions")

        if type_of_target(y_true) != "binary":
            raise ValueError(
                f"The target y is not binary. Got {type_of_target(y_true)} type of"
                " target."
            )

        check_consistent_length(y_true, y_pred, sample_weight)
        pos_label = _check_pos_label_consistency(pos_label, y_true)

        name = name if name is not None else "Classifier"

        return pos_label, name

    @classmethod
    def _validate_from_cv_results_params(
        cls,
        cv_results,
        X,
        y,
        *,
        sample_weight,
        pos_label,
    ):
        """Validate inputs of a `from_cv_results` call and return `pos_label`.

        `cv_results` must come from `cross_validate` called with
        `return_estimator=True` and `return_indices=True`; `X`/`y` must match
        the data used there and `y` must be binary.
        """
        check_matplotlib_support(f"{cls.__name__}.from_cv_results")

        required_keys = {"estimator", "indices"}
        if not all(key in cv_results for key in required_keys):
            raise ValueError(
                "`cv_results` does not contain one of the following required keys: "
                f"{required_keys}. Set explicitly the parameters "
                "`return_estimator=True` and `return_indices=True` to the function"
                "`cross_validate`."
            )

        # Size check uses the first fold only; assumes all folds partition the
        # same number of samples.
        train_size, test_size = (
            len(cv_results["indices"]["train"][0]),
            len(cv_results["indices"]["test"][0]),
        )

        if _num_samples(X) != train_size + test_size:
            raise ValueError(
                "`X` does not contain the correct number of samples. "
                f"Expected {train_size + test_size}, got {_num_samples(X)}."
            )

        if type_of_target(y) != "binary":
            raise ValueError(
                f"The target `y` is not binary. Got {type_of_target(y)} type of target."
            )
        check_consistent_length(X, y, sample_weight)

        try:
            pos_label = _check_pos_label_consistency(pos_label, y)
        except ValueError as e:
            # Adapt error message
            raise ValueError(str(e).replace("y_true", "y"))

        return pos_label

    @staticmethod
    def _get_legend_label(curve_legend_metric, curve_name, legend_metric_name):
        """Helper to get legend label using `name` and `legend_metric`.

        Returns None when neither a metric value nor a name is available.
        """
        if curve_legend_metric is not None and curve_name is not None:
            label = f"{curve_name} ({legend_metric_name} = {curve_legend_metric:0.2f})"
        elif curve_legend_metric is not None:
            label = f"{legend_metric_name} = {curve_legend_metric:0.2f}"
        elif curve_name is not None:
            label = curve_name
        else:
            label = None
        return label

    @staticmethod
    def _validate_curve_kwargs(
        n_curves,
        name,
        legend_metric,
        legend_metric_name,
        curve_kwargs,
        **kwargs,
    ):
        """Get validated line kwargs for each curve.

        Parameters
        ----------
        n_curves : int
            Number of curves.

        name : list of str or None
            Name for labeling legend entries.

        legend_metric : dict
            Dictionary with "mean" and "std" keys, or "metric" key of metric
            values for each curve. If None, "label" will not contain metric values.

        legend_metric_name : str
            Name of the summary value provided in `legend_metrics`.

        curve_kwargs : dict or list of dict or None
            Dictionary with keywords passed to the matplotlib's `plot` function
            to draw the individual curves. If a list is provided, the
            parameters are applied to the curves sequentially. If a single
            dictionary is provided, the same parameters are applied to all
            curves.

        **kwargs : dict
            Deprecated. Keyword arguments to be passed to matplotlib's `plot`.

        Returns
        -------
        list of dict
            One validated kwargs dict per curve, each including a "label" key.
        """
        # TODO(1.9): Remove deprecated **kwargs
        if curve_kwargs and kwargs:
            raise ValueError(
                "Cannot provide both `curve_kwargs` and `kwargs`. `**kwargs` is "
                "deprecated in 1.7 and will be removed in 1.9. Pass all matplotlib "
                "arguments to `curve_kwargs` as a dictionary."
            )
        if kwargs:
            warnings.warn(
                "`**kwargs` is deprecated and will be removed in 1.9. Pass all "
                "matplotlib arguments to `curve_kwargs` as a dictionary instead.",
                FutureWarning,
            )
            curve_kwargs = kwargs

        if isinstance(curve_kwargs, list) and len(curve_kwargs) != n_curves:
            raise ValueError(
                f"`curve_kwargs` must be None, a dictionary or a list of length "
                f"{n_curves}. Got: {curve_kwargs}."
            )

        # Ensure valid `name` and `curve_kwargs` combination.
        if (
            isinstance(name, list)
            and len(name) != 1
            and not isinstance(curve_kwargs, list)
        ):
            raise ValueError(
                "To avoid labeling individual curves that have the same appearance, "
                f"`curve_kwargs` should be a list of {n_curves} dictionaries. "
                "Alternatively, set `name` to `None` or a single string to label "
                "a single legend entry with mean ROC AUC score of all curves."
            )

        # Ensure `name` is of the correct length
        if isinstance(name, str):
            name = [name]
        if isinstance(name, list) and len(name) == 1:
            name = name * n_curves
        name = [None] * n_curves if name is None else name

        # Ensure `curve_kwargs` is of correct length
        if isinstance(curve_kwargs, Mapping):
            curve_kwargs = [curve_kwargs] * n_curves

        default_multi_curve_kwargs = {"alpha": 0.5, "linestyle": "--", "color": "blue"}
        if curve_kwargs is None:
            if n_curves > 1:
                curve_kwargs = [default_multi_curve_kwargs] * n_curves
            else:
                curve_kwargs = [{}]

        labels = []
        if "mean" in legend_metric:
            # Aggregate mode: one legend entry summarizing all curves.
            label_aggregate = _BinaryClassifierCurveDisplayMixin._get_legend_label(
                legend_metric["mean"], name[0], legend_metric_name
            )
            # Note: "std" always `None` when "mean" is `None` - no metric value added
            # to label in this case
            if legend_metric["std"] is not None:
                # Add the "+/- std" to the end (in brackets if name provided)
                if name[0] is not None:
                    label_aggregate = (
                        label_aggregate[:-1] + f" +/- {legend_metric['std']:0.2f})"
                    )
                else:
                    label_aggregate = (
                        label_aggregate + f" +/- {legend_metric['std']:0.2f}"
                    )
            # Add `label` for first curve only, set to `None` for remaining curves
            labels.extend([label_aggregate] + [None] * (n_curves - 1))
        else:
            # Per-curve mode: one legend entry per curve.
            for curve_legend_metric, curve_name in zip(legend_metric["metric"], name):
                labels.append(
                    _BinaryClassifierCurveDisplayMixin._get_legend_label(
                        curve_legend_metric, curve_name, legend_metric_name
                    )
                )

        # Merge the computed labels with the user style kwargs, resolving
        # matplotlib alias conflicts.
        curve_kwargs_ = [
            _validate_style_kwargs({"label": label}, curve_kwargs[fold_idx])
            for fold_idx, label in enumerate(labels)
        ]
        return curve_kwargs_
|
||||
|
||||
|
||||
def _validate_score_name(score_name, scoring, negate_score):
|
||||
"""Validate the `score_name` parameter.
|
||||
|
||||
If `score_name` is provided, we just return it as-is.
|
||||
If `score_name` is `None`, we use `Score` if `negate_score` is `False` and
|
||||
`Negative score` otherwise.
|
||||
If `score_name` is a string or a callable, we infer the name. We replace `_` by
|
||||
spaces and capitalize the first letter. We remove `neg_` and replace it by
|
||||
`"Negative"` if `negate_score` is `False` or just remove it otherwise.
|
||||
"""
|
||||
if score_name is not None:
|
||||
return score_name
|
||||
elif scoring is None:
|
||||
return "Negative score" if negate_score else "Score"
|
||||
else:
|
||||
score_name = scoring.__name__ if callable(scoring) else scoring
|
||||
if negate_score:
|
||||
if score_name.startswith("neg_"):
|
||||
score_name = score_name[4:]
|
||||
else:
|
||||
score_name = f"Negative {score_name}"
|
||||
elif score_name.startswith("neg_"):
|
||||
score_name = f"Negative {score_name[4:]}"
|
||||
score_name = score_name.replace("_", " ")
|
||||
return score_name.capitalize()
|
||||
|
||||
|
||||
def _interval_max_min_ratio(data):
|
||||
"""Compute the ratio between the largest and smallest inter-point distances.
|
||||
|
||||
A value larger than 5 typically indicates that the parameter range would
|
||||
better be displayed with a log scale while a linear scale would be more
|
||||
suitable otherwise.
|
||||
"""
|
||||
diff = np.diff(np.sort(data))
|
||||
return diff.max() / diff.min()
|
||||
|
||||
|
||||
def _validate_style_kwargs(default_style_kwargs, user_style_kwargs):
|
||||
"""Create valid style kwargs by avoiding Matplotlib alias errors.
|
||||
|
||||
Matplotlib raises an error when, for example, 'color' and 'c', or 'linestyle' and
|
||||
'ls', are specified together. To avoid this, we automatically keep only the one
|
||||
specified by the user and raise an error if the user specifies both.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
default_style_kwargs : dict
|
||||
The Matplotlib style kwargs used by default in the scikit-learn display.
|
||||
user_style_kwargs : dict
|
||||
The user-defined Matplotlib style kwargs.
|
||||
|
||||
Returns
|
||||
-------
|
||||
valid_style_kwargs : dict
|
||||
The validated style kwargs taking into account both default and user-defined
|
||||
Matplotlib style kwargs.
|
||||
"""
|
||||
|
||||
invalid_to_valid_kw = {
|
||||
"ls": "linestyle",
|
||||
"c": "color",
|
||||
"ec": "edgecolor",
|
||||
"fc": "facecolor",
|
||||
"lw": "linewidth",
|
||||
"mec": "markeredgecolor",
|
||||
"mfcalt": "markerfacecoloralt",
|
||||
"ms": "markersize",
|
||||
"mew": "markeredgewidth",
|
||||
"mfc": "markerfacecolor",
|
||||
"aa": "antialiased",
|
||||
"ds": "drawstyle",
|
||||
"font": "fontproperties",
|
||||
"family": "fontfamily",
|
||||
"name": "fontname",
|
||||
"size": "fontsize",
|
||||
"stretch": "fontstretch",
|
||||
"style": "fontstyle",
|
||||
"variant": "fontvariant",
|
||||
"weight": "fontweight",
|
||||
"ha": "horizontalalignment",
|
||||
"va": "verticalalignment",
|
||||
"ma": "multialignment",
|
||||
}
|
||||
for invalid_key, valid_key in invalid_to_valid_kw.items():
|
||||
if invalid_key in user_style_kwargs and valid_key in user_style_kwargs:
|
||||
raise TypeError(
|
||||
f"Got both {invalid_key} and {valid_key}, which are aliases of one "
|
||||
"another"
|
||||
)
|
||||
valid_style_kwargs = default_style_kwargs.copy()
|
||||
|
||||
for key in user_style_kwargs.keys():
|
||||
if key in invalid_to_valid_kw:
|
||||
valid_style_kwargs[invalid_to_valid_kw[key]] = user_style_kwargs[key]
|
||||
else:
|
||||
valid_style_kwargs[key] = user_style_kwargs[key]
|
||||
|
||||
return valid_style_kwargs
|
||||
|
||||
|
||||
def _despine(ax):
|
||||
"""Remove the top and right spines of the plot.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax : matplotlib.axes.Axes
|
||||
The axes of the plot to despine.
|
||||
"""
|
||||
for s in ["top", "right"]:
|
||||
ax.spines[s].set_visible(False)
|
||||
for s in ["bottom", "left"]:
|
||||
ax.spines[s].set_bounds(0, 1)
|
||||
|
||||
|
||||
def _deprecate_estimator_name(estimator_name, name, version):
    """Deprecate `estimator_name` in favour of `name`."""
    version = parse_version(version)
    # Deprecation window is two minor releases.
    version_remove = f"{version.major}.{version.minor + 2}"

    # "deprecated" is the sentinel default meaning "not passed by the user".
    if estimator_name == "deprecated":
        return name

    if name:
        raise ValueError(
            "Cannot provide both `estimator_name` and `name`. `estimator_name` "
            f"is deprecated in {version} and will be removed in {version_remove}. "
            "Use `name` only."
        )
    warnings.warn(
        f"`estimator_name` is deprecated in {version} and will be removed in "
        f"{version_remove}. Use `name` instead.",
        FutureWarning,
    )
    return estimator_name
|
||||
|
||||
|
||||
def _convert_to_list_leaving_none(param):
|
||||
"""Convert parameters to a list, leaving `None` as is."""
|
||||
if param is None:
|
||||
return None
|
||||
if isinstance(param, list):
|
||||
return param
|
||||
return [param]
|
||||
|
||||
|
||||
def _check_param_lengths(required, optional, class_name):
|
||||
"""Check required and optional parameters are of the same length."""
|
||||
optional_provided = {}
|
||||
for name, param in optional.items():
|
||||
if isinstance(param, list):
|
||||
optional_provided[name] = param
|
||||
|
||||
all_params = {**required, **optional_provided}
|
||||
if len({len(param) for param in all_params.values()}) > 1:
|
||||
param_keys = [key for key in all_params.keys()]
|
||||
# Note: below code requires `len(param_keys) >= 2`, which is the case for all
|
||||
# display classes
|
||||
params_formatted = " and ".join([", ".join(param_keys[:-1]), param_keys[-1]])
|
||||
or_plot = ""
|
||||
if "'name' (or self.name)" in param_keys:
|
||||
or_plot = " (or `plot`)"
|
||||
lengths_formatted = ", ".join(
|
||||
f"{key}: {len(value)}" for key, value in all_params.items()
|
||||
)
|
||||
raise ValueError(
|
||||
f"{params_formatted} from `{class_name}` initialization{or_plot}, "
|
||||
f"should all be lists of the same length. Got: {lengths_formatted}"
|
||||
)
|
||||
463
venv/lib/python3.12/site-packages/sklearn/utils/_pprint.py
Normal file
463
venv/lib/python3.12/site-packages/sklearn/utils/_pprint.py
Normal file
@@ -0,0 +1,463 @@
|
||||
"""This module contains the _EstimatorPrettyPrinter class used in
|
||||
BaseEstimator.__repr__ for pretty-printing estimators"""
|
||||
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
|
||||
# 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018 Python Software Foundation;
|
||||
# All Rights Reserved
|
||||
|
||||
# Authors: Fred L. Drake, Jr. <fdrake@acm.org> (built-in CPython pprint module)
|
||||
# Nicolas Hug (scikit-learn specific changes)
|
||||
|
||||
# License: PSF License version 2 (see below)
|
||||
|
||||
# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
|
||||
# --------------------------------------------
|
||||
|
||||
# 1. This LICENSE AGREEMENT is between the Python Software Foundation ("PSF"),
|
||||
# and the Individual or Organization ("Licensee") accessing and otherwise
|
||||
# using this software ("Python") in source or binary form and its associated
|
||||
# documentation.
|
||||
|
||||
# 2. Subject to the terms and conditions of this License Agreement, PSF hereby
|
||||
# grants Licensee a nonexclusive, royalty-free, world-wide license to
|
||||
# reproduce, analyze, test, perform and/or display publicly, prepare
|
||||
# derivative works, distribute, and otherwise use Python alone or in any
|
||||
# derivative version, provided, however, that PSF's License Agreement and
|
||||
# PSF's notice of copyright, i.e., "Copyright (c) 2001, 2002, 2003, 2004,
|
||||
# 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016,
|
||||
# 2017, 2018 Python Software Foundation; All Rights Reserved" are retained in
|
||||
# Python alone or in any derivative version prepared by Licensee.
|
||||
|
||||
# 3. In the event Licensee prepares a derivative work that is based on or
|
||||
# incorporates Python or any part thereof, and wants to make the derivative
|
||||
# work available to others as provided herein, then Licensee hereby agrees to
|
||||
# include in any such work a brief summary of the changes made to Python.
|
||||
|
||||
# 4. PSF is making Python available to Licensee on an "AS IS" basis. PSF MAKES
|
||||
# NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT
|
||||
# NOT LIMITATION, PSF MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF
|
||||
# MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF
|
||||
# PYTHON WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.
|
||||
|
||||
# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON FOR ANY
|
||||
# INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF
|
||||
# MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, OR ANY DERIVATIVE
|
||||
# THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
||||
|
||||
# 6. This License Agreement will automatically terminate upon a material
|
||||
# breach of its terms and conditions.
|
||||
|
||||
# 7. Nothing in this License Agreement shall be deemed to create any
|
||||
# relationship of agency, partnership, or joint venture between PSF and
|
||||
# Licensee. This License Agreement does not grant permission to use PSF
|
||||
# trademarks or trade name in a trademark sense to endorse or promote products
|
||||
# or services of Licensee, or any third party.
|
||||
|
||||
# 8. By copying, installing or otherwise using Python, Licensee agrees to be
|
||||
# bound by the terms and conditions of this License Agreement.
|
||||
|
||||
|
||||
# Brief summary of changes to original code:
|
||||
# - "compact" parameter is supported for dicts, not just lists or tuples
|
||||
# - estimators have a custom handler, they're not just treated as objects
|
||||
# - long sequences (lists, tuples, dict items) with more than N elements are
|
||||
# shortened using ellipsis (', ...') at the end.
|
||||
|
||||
import inspect
|
||||
import pprint
|
||||
|
||||
from .._config import get_config
|
||||
from ..base import BaseEstimator
|
||||
from ._missing import is_scalar_nan
|
||||
|
||||
|
||||
class KeyValTuple(tuple):
    """Dummy class for correctly rendering key-value tuples from dicts."""

    def __repr__(self):
        # needed for _dispatch[tuple.__repr__] not to be overridden
        return tuple.__repr__(self)
|
||||
|
||||
|
||||
class KeyValTupleParam(KeyValTuple):
    """Dummy class for correctly rendering key-value tuples from parameters."""
|
||||
|
||||
|
||||
def _changed_params(estimator):
    """Return dict (param_name: value) of parameters that were given to
    estimator with non-default values."""

    params = estimator.get_params(deep=False)
    init_func = getattr(estimator.__init__, "deprecated_original", estimator.__init__)
    defaults = {
        name: param.default
        for name, param in inspect.signature(init_func).parameters.items()
    }

    def has_changed(k, v):
        # k coming from **kwargs has no declared default: always show it.
        if k not in defaults:
            return True
        default = defaults[k]
        # No default value declared: always show it.
        if default == inspect._empty:
            return True
        # try to avoid calling repr on nested estimators
        if isinstance(v, BaseEstimator) and v.__class__ != default.__class__:
            return True
        # NaN never equals itself, so treat a pair of scalar NaNs as unchanged.
        if is_scalar_nan(default) and is_scalar_nan(v):
            return False
        # Use repr as a last resort. It may be expensive.
        return repr(v) != repr(default)

    return {k: v for k, v in params.items() if has_changed(k, v)}
|
||||
|
||||
|
||||
class _EstimatorPrettyPrinter(pprint.PrettyPrinter):
|
||||
"""Pretty Printer class for estimator objects.
|
||||
|
||||
This extends the pprint.PrettyPrinter class, because:
|
||||
- we need estimators to be printed with their parameters, e.g.
|
||||
Estimator(param1=value1, ...) which is not supported by default.
|
||||
- the 'compact' parameter of PrettyPrinter is ignored for dicts, which
|
||||
may lead to very long representations that we want to avoid.
|
||||
|
||||
Quick overview of pprint.PrettyPrinter (see also
|
||||
https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers):
|
||||
|
||||
- the entry point is the _format() method which calls format() (overridden
|
||||
here)
|
||||
- format() directly calls _safe_repr() for a first try at rendering the
|
||||
object
|
||||
- _safe_repr formats the whole object recursively, only calling itself,
|
||||
not caring about line length or anything
|
||||
- back to _format(), if the output string is too long, _format() then calls
|
||||
the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on
|
||||
the type of the object. This where the line length and the compact
|
||||
parameters are taken into account.
|
||||
- those _pprint_TYPE() methods will internally use the format() method for
|
||||
rendering the nested objects of an object (e.g. the elements of a list)
|
||||
|
||||
In the end, everything has to be implemented twice: in _safe_repr and in
|
||||
the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not
|
||||
straightforward to extend (especially when we want a compact output), so
|
||||
the code is a bit convoluted.
|
||||
|
||||
This class overrides:
|
||||
- format() to support the changed_only parameter
|
||||
- _safe_repr to support printing of estimators (for when they fit on a
|
||||
single line)
|
||||
- _format_dict_items so that dict are correctly 'compacted'
|
||||
- _format_items so that ellipsis is used on long lists and tuples
|
||||
|
||||
When estimators cannot be printed on a single line, the builtin _format()
|
||||
will call _pprint_estimator() because it was registered to do so (see
|
||||
_dispatch[BaseEstimator.__repr__] = _pprint_estimator).
|
||||
|
||||
both _format_dict_items() and _pprint_estimator() use the
|
||||
_format_params_or_dict_items() method that will format parameters and
|
||||
key-value pairs respecting the compact parameter. This method needs another
|
||||
subroutine _pprint_key_val_tuple() used when a parameter or a key-value
|
||||
pair is too long to fit on a single line. This subroutine is called in
|
||||
_format() and is registered as well in the _dispatch dict (just like
|
||||
_pprint_estimator). We had to create the two classes KeyValTuple and
|
||||
KeyValTupleParam for this.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
indent=1,
|
||||
width=80,
|
||||
depth=None,
|
||||
stream=None,
|
||||
*,
|
||||
compact=False,
|
||||
indent_at_name=True,
|
||||
n_max_elements_to_show=None,
|
||||
):
|
||||
super().__init__(indent, width, depth, stream, compact=compact)
|
||||
self._indent_at_name = indent_at_name
|
||||
if self._indent_at_name:
|
||||
self._indent_per_level = 1 # ignore indent param
|
||||
self._changed_only = get_config()["print_changed_only"]
|
||||
# Max number of elements in a list, dict, tuple until we start using
|
||||
# ellipsis. This also affects the number of arguments of an estimators
|
||||
# (they are treated as dicts)
|
||||
self.n_max_elements_to_show = n_max_elements_to_show
|
||||
|
||||
def format(self, object, context, maxlevels, level):
|
||||
return _safe_repr(
|
||||
object, context, maxlevels, level, changed_only=self._changed_only
|
||||
)
|
||||
|
||||
def _pprint_estimator(self, object, stream, indent, allowance, context, level):
    """Render an estimator as ``ClassName(param=value, ...)``.

    Registered in ``_dispatch`` for ``BaseEstimator.__repr__`` so the
    builtin ``_format`` delegates estimators here.
    """
    stream.write(object.__class__.__name__ + "(")
    if self._indent_at_name:
        # Align wrapped parameters under the opening parenthesis.
        indent += len(object.__class__.__name__)

    if self._changed_only:
        params = _changed_params(object)
    else:
        params = object.get_params(deep=False)

    # allowance + 1 reserves room for the closing parenthesis below.
    self._format_params(
        sorted(params.items()), stream, indent, allowance + 1, context, level
    )
    stream.write(")")
|
||||
|
||||
def _format_dict_items(self, items, stream, indent, allowance, context, level):
    # Dict items are rendered as <'key': value> (is_dict=True).
    return self._format_params_or_dict_items(
        items, stream, indent, allowance, context, level, is_dict=True
    )
|
||||
|
||||
def _format_params(self, items, stream, indent, allowance, context, level):
    # Estimator parameters are rendered as <key=value> (is_dict=False).
    return self._format_params_or_dict_items(
        items, stream, indent, allowance, context, level, is_dict=False
    )
|
||||
|
||||
def _format_params_or_dict_items(
    self, object, stream, indent, allowance, context, level, is_dict
):
    """Format dict items or parameters respecting the compact=True
    parameter. For some reason, the builtin rendering of dict items doesn't
    respect compact=True and will use one line per key-value if all cannot
    fit in a single line.

    Dict items will be rendered as <'key': value> while params will be
    rendered as <key=value>. The implementation is mostly copy/pasting from
    the builtin _format_items().

    This also adds ellipsis if the number of items is greater than
    self.n_max_elements_to_show.
    """
    write = stream.write
    indent += self._indent_per_level
    delimnl = ",\n" + " " * indent
    delim = ""
    width = max_width = self._width - indent + 1
    it = iter(object)
    try:
        next_ent = next(it)
    except StopIteration:
        # Nothing to render.
        return
    last = False
    n_items = 0
    while not last:
        if n_items == self.n_max_elements_to_show:
            write(", ...")
            break
        n_items += 1
        ent = next_ent
        try:
            next_ent = next(it)
        except StopIteration:
            # Current entry is the final one: reserve the caller's
            # allowance (e.g. the closing bracket) on this line.
            last = True
            max_width -= allowance
            width -= allowance
        if self._compact:
            k, v = ent
            krepr = self._repr(k, context, level)
            vrepr = self._repr(v, context, level)
            if not is_dict:
                # Parameters render as key=value: drop the quotes that
                # repr() put around the (string) key.
                krepr = krepr.strip("'")
            middle = ": " if is_dict else "="
            rep = krepr + middle + vrepr
            w = len(rep) + 2  # +2 accounts for the ", " separator
            if width < w:
                # Entry does not fit on the current line: wrap.
                width = max_width
                if delim:
                    delim = delimnl
            if width >= w:
                width -= w
                write(delim)
                delim = ", "
                write(rep)
                continue
        # Non-compact (or too-wide) path: one entry per line, routed
        # through _format via the KeyValTuple dispatch entries.
        write(delim)
        delim = delimnl
        class_ = KeyValTuple if is_dict else KeyValTupleParam
        self._format(
            class_(ent), stream, indent, allowance if last else 1, context, level
        )
|
||||
|
||||
def _format_items(self, items, stream, indent, allowance, context, level):
    """Format the items of an iterable (list, tuple...). Same as the
    built-in _format_items, with support for ellipsis if the number of
    elements is greater than self.n_max_elements_to_show.
    """
    write = stream.write
    indent += self._indent_per_level
    if self._indent_per_level > 1:
        write((self._indent_per_level - 1) * " ")
    delimnl = ",\n" + " " * indent
    delim = ""
    width = max_width = self._width - indent + 1
    it = iter(items)
    try:
        next_ent = next(it)
    except StopIteration:
        # Empty iterable: nothing to render.
        return
    last = False
    n_items = 0
    while not last:
        if n_items == self.n_max_elements_to_show:
            write(", ...")
            break
        n_items += 1
        ent = next_ent
        try:
            next_ent = next(it)
        except StopIteration:
            # Final element: reserve the caller's allowance (closing
            # bracket) on this line.
            last = True
            max_width -= allowance
            width -= allowance
        if self._compact:
            rep = self._repr(ent, context, level)
            w = len(rep) + 2  # +2 accounts for the ", " separator
            if width < w:
                # Element does not fit on the current line: wrap.
                width = max_width
                if delim:
                    delim = delimnl
            if width >= w:
                width -= w
                write(delim)
                delim = ", "
                write(rep)
                continue
        # Non-compact (or too-wide) path: one element per line.
        write(delim)
        delim = delimnl
        self._format(ent, stream, indent, allowance if last else 1, context, level)
|
||||
|
||||
def _pprint_key_val_tuple(self, object, stream, indent, allowance, context, level):
    """Pretty printing for key-value tuples from dict or parameters."""
    k, v = object
    rep = self._repr(k, context, level)
    if isinstance(object, KeyValTupleParam):
        # Estimator parameter: key=value without quotes around the key.
        rep = rep.strip("'")
        middle = "="
    else:
        # Dict item: 'key': value.
        middle = ": "
    stream.write(rep)
    stream.write(middle)
    # Indent the value past the key and separator so wrapped values line up.
    self._format(
        v, stream, indent + len(rep) + len(middle), allowance, context, level
    )
|
||||
|
||||
# Note: need to copy _dispatch to prevent instances of the builtin
# PrettyPrinter class to call methods of _EstimatorPrettyPrinter (see issue
# 12906)
# mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch"
_dispatch = pprint.PrettyPrinter._dispatch.copy()  # type: ignore[attr-defined]
# Route estimators and key-value tuples to the custom printers above; the
# builtin _format looks up this table keyed on the type's __repr__.
_dispatch[BaseEstimator.__repr__] = _pprint_estimator
_dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple
|
||||
|
||||
|
||||
def _safe_repr(object, context, maxlevels, level, changed_only=False):
    """Same as the builtin _safe_repr, with added support for Estimator
    objects.

    Returns a ``(repr_string, readable, recursive)`` triple like the
    builtin: ``readable`` means the repr could round-trip through eval,
    ``recursive`` means a self-reference was detected. ``context`` maps
    ids of objects currently being formatted, to detect cycles.
    """
    typ = type(object)

    if typ in pprint._builtin_scalars:
        return repr(object), True, False

    r = getattr(typ, "__repr__", None)
    # Plain dicts (default repr): recurse ourselves so changed_only
    # propagates into nested estimators.
    if issubclass(typ, dict) and r is dict.__repr__:
        if not object:
            return "{}", True, False
        objid = id(object)
        if maxlevels and level >= maxlevels:
            return "{...}", False, objid in context
        if objid in context:
            # Already being formatted higher up the stack: cycle.
            return pprint._recursion(object), False, True
        context[objid] = 1
        readable = True
        recursive = False
        components = []
        append = components.append
        level += 1
        saferepr = _safe_repr
        items = sorted(object.items(), key=pprint._safe_tuple)
        for k, v in items:
            krepr, kreadable, krecur = saferepr(
                k, context, maxlevels, level, changed_only=changed_only
            )
            vrepr, vreadable, vrecur = saferepr(
                v, context, maxlevels, level, changed_only=changed_only
            )
            append("%s: %s" % (krepr, vrepr))
            readable = readable and kreadable and vreadable
            if krecur or vrecur:
                recursive = True
        del context[objid]
        return "{%s}" % ", ".join(components), readable, recursive

    if (issubclass(typ, list) and r is list.__repr__) or (
        issubclass(typ, tuple) and r is tuple.__repr__
    ):
        if issubclass(typ, list):
            if not object:
                return "[]", True, False
            format = "[%s]"
        elif len(object) == 1:
            # Single-element tuple needs the trailing comma.
            format = "(%s,)"
        else:
            if not object:
                return "()", True, False
            format = "(%s)"
        objid = id(object)
        if maxlevels and level >= maxlevels:
            return format % "...", False, objid in context
        if objid in context:
            return pprint._recursion(object), False, True
        context[objid] = 1
        readable = True
        recursive = False
        components = []
        append = components.append
        level += 1
        for o in object:
            orepr, oreadable, orecur = _safe_repr(
                o, context, maxlevels, level, changed_only=changed_only
            )
            append(orepr)
            if not oreadable:
                readable = False
            if orecur:
                recursive = True
        del context[objid]
        return format % ", ".join(components), readable, recursive

    # Estimators: render as ClassName(param=value, ...), honoring
    # changed_only for the parameter selection.
    if issubclass(typ, BaseEstimator):
        objid = id(object)
        if maxlevels and level >= maxlevels:
            return f"{typ.__name__}(...)", False, objid in context
        if objid in context:
            return pprint._recursion(object), False, True
        context[objid] = 1
        readable = True
        recursive = False
        if changed_only:
            params = _changed_params(object)
        else:
            params = object.get_params(deep=False)
        components = []
        append = components.append
        level += 1
        saferepr = _safe_repr
        items = sorted(params.items(), key=pprint._safe_tuple)
        for k, v in items:
            krepr, kreadable, krecur = saferepr(
                k, context, maxlevels, level, changed_only=changed_only
            )
            vrepr, vreadable, vrecur = saferepr(
                v, context, maxlevels, level, changed_only=changed_only
            )
            # key=value form, without quotes around the key.
            append("%s=%s" % (krepr.strip("'"), vrepr))
            readable = readable and kreadable and vreadable
            if krecur or vrecur:
                recursive = True
        del context[objid]
        return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive)

    # Fallback: the type's own repr. Reprs starting with '<' (default
    # object repr) are considered non-readable.
    rep = repr(object)
    return rep, (rep and not rep.startswith("<")), False
|
||||
Binary file not shown.
34
venv/lib/python3.12/site-packages/sklearn/utils/_random.pxd
Normal file
34
venv/lib/python3.12/site-packages/sklearn/utils/_random.pxd
Normal file
@@ -0,0 +1,34 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from ._typedefs cimport uint32_t
|
||||
|
||||
|
||||
# Fallback seed used when a caller passes a zero state (see our_rand_r).
cdef inline uint32_t DEFAULT_SEED = 1


cdef enum:
    # Max value for our rand_r replacement (near the bottom).
    # We don't use RAND_MAX because it's different across platforms and
    # particularly tiny on Windows/MSVC.
    # It corresponds to the maximum representable value for
    # 32-bit signed integers (i.e. 2^31 - 1).
    RAND_R_MAX = 2147483647


# rand_r replacement using a 32bit XorShift generator
# See http://www.jstatsoft.org/v08/i14/paper for details
cdef inline uint32_t our_rand_r(uint32_t* seed) nogil:
    """Generate a pseudo-random np.uint32 from a np.uint32 seed"""
    # seed shouldn't ever be 0: zero is a fixed point of the xorshift
    # transition, so the stream would be all zeros.
    if (seed[0] == 0):
        seed[0] = DEFAULT_SEED

    # 13/17/5 xorshift triple; state is updated in place through the pointer.
    seed[0] ^= <uint32_t>(seed[0] << 13)
    seed[0] ^= <uint32_t>(seed[0] >> 17)
    seed[0] ^= <uint32_t>(seed[0] << 5)

    # Use the modulo to make sure that we don't return a values greater than the
    # maximum representable value for signed 32bit integers (i.e. 2^31 - 1).
    # Note that the parenthesis are needed to avoid overflow: here
    # RAND_R_MAX is cast to uint32_t before 1 is added.
    return seed[0] % ((<uint32_t>RAND_R_MAX) + 1)
|
||||
355
venv/lib/python3.12/site-packages/sklearn/utils/_random.pyx
Normal file
355
venv/lib/python3.12/site-packages/sklearn/utils/_random.pyx
Normal file
@@ -0,0 +1,355 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
"""
|
||||
Random utility function
|
||||
=======================
|
||||
This module complements missing features of ``numpy.random``.
|
||||
|
||||
The module contains:
|
||||
* Several algorithms to sample integers without replacement.
|
||||
* Fast rand_r alternative based on xor shifts
|
||||
"""
|
||||
import numpy as np
|
||||
from . import check_random_state
|
||||
|
||||
from ._typedefs cimport intp_t
|
||||
|
||||
|
||||
# Module-level default seed (mirrors DEFAULT_SEED in _random.pxd).
cdef uint32_t DEFAULT_SEED = 1


# Compatibility type to always accept the default int type used by NumPy, both
# before and after NumPy 2. On Windows, `long` does not always match `intp_t`.
# See the comments in the `sample_without_replacement` Python function for more
# details.
ctypedef fused default_int:
    intp_t
    long
|
||||
|
||||
|
||||
cpdef _sample_without_replacement_check_input(default_int n_population,
                                              default_int n_samples):
    """Check that inputs are consistent for sample_without_replacement.

    Parameters
    ----------
    n_population : int
        The size of the set to sample from. Must be non-negative.

    n_samples : int
        The number of integers to sample. Must not exceed ``n_population``.

    Raises
    ------
    ValueError
        If ``n_population`` is negative, or ``n_samples > n_population``.
    """
    if n_population < 0:
        # The check allows 0 (sampling nothing from an empty population),
        # so the message states the actual condition; the previous message
        # ("should be greater than 0") contradicted the `< 0` test.
        raise ValueError('n_population should be non-negative, got %s.'
                         % n_population)

    if n_samples > n_population:
        raise ValueError('n_population should be greater or equal than '
                         'n_samples, got n_samples > n_population (%s > %s)'
                         % (n_samples, n_population))
|
||||
|
||||
|
||||
cpdef _sample_without_replacement_with_tracking_selection(
        default_int n_population,
        default_int n_samples,
        random_state=None):
    r"""Sample integers without replacement.

    Select n_samples integers from the set [0, n_population) without
    replacement.

    Time complexity:
        - Worst-case: unbounded
        - Average-case:
            O(O(np.random.randint) * \sum_{i=1}^n_samples 1 /
                                              (1 - i / n_population)))
            <= O(O(np.random.randint) *
                           n_population * ln((n_population - 2)
                                             /(n_population - 1 - n_samples)))
            <= O(O(np.random.randint) *
                 n_population * 1 / (1 - n_samples / n_population))

    Space complexity of O(n_samples) in a python set.


    Parameters
    ----------
    n_population : int
        The size of the set to sample from.

    n_samples : int
        The number of integer to sample.

    random_state : int, RandomState instance or None, default=None
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Returns
    -------
    out : ndarray of shape (n_samples,)
        The sampled subsets of integer.
    """
    _sample_without_replacement_check_input(n_population, n_samples)

    cdef default_int i
    cdef default_int j
    cdef default_int[::1] out = np.empty((n_samples, ), dtype=int)

    rng = check_random_state(random_state)
    rng_randint = rng.randint

    # The following line of code are heavily inspired from python core,
    # more precisely of random.sample.
    cdef set selected = set()

    for i in range(n_samples):
        # Rejection sampling: redraw until an integer not yet selected
        # comes up. Unbounded in the worst case, fast when
        # n_samples << n_population.
        j = rng_randint(n_population)
        while j in selected:
            j = rng_randint(n_population)
        selected.add(j)
        out[i] = j

    return np.asarray(out)
|
||||
|
||||
|
||||
cpdef _sample_without_replacement_with_pool(default_int n_population,
                                            default_int n_samples,
                                            random_state=None):
    """Sample integers without replacement.

    Select n_samples integers from the set [0, n_population) without
    replacement.

    Time complexity: O(n_population + O(np.random.randint) * n_samples)

    Space complexity of O(n_population + n_samples).


    Parameters
    ----------
    n_population : int
        The size of the set to sample from.

    n_samples : int
        The number of integer to sample.

    random_state : int, RandomState instance or None, default=None
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Returns
    -------
    out : ndarray of shape (n_samples,)
        The sampled subsets of integer.
    """
    _sample_without_replacement_check_input(n_population, n_samples)

    cdef default_int i
    cdef default_int j
    cdef default_int[::1] out = np.empty((n_samples,), dtype=int)
    cdef default_int[::1] pool = np.empty((n_population,), dtype=int)

    rng = check_random_state(random_state)
    rng_randint = rng.randint

    # Initialize the pool with the full population [0, n_population).
    for i in range(n_population):
        pool[i] = i

    # The following line of code are heavily inspired from python core,
    # more precisely of random.sample.
    for i in range(n_samples):
        j = rng_randint(n_population - i)  # invariant: non-selected at [0,n-i)
        out[i] = pool[j]
        pool[j] = pool[n_population - i - 1]  # move non-selected item into vacancy

    return np.asarray(out)
|
||||
|
||||
|
||||
cpdef _sample_without_replacement_with_reservoir_sampling(
    default_int n_population,
    default_int n_samples,
    random_state=None
):
    """Sample integers without replacement.

    Select n_samples integers from the set [0, n_population) without
    replacement.

    Time complexity of
        O((n_population - n_samples) * O(np.random.randint) + n_samples)
    Space complexity of O(n_samples)


    Parameters
    ----------
    n_population : int
        The size of the set to sample from.

    n_samples : int
        The number of integer to sample.

    random_state : int, RandomState instance or None, default=None
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Returns
    -------
    out : ndarray of shape (n_samples,)
        The sampled subsets of integer. The order of the items is not
        necessarily random. Use a random permutation of the array if the order
        of the items has to be randomized.
    """
    _sample_without_replacement_check_input(n_population, n_samples)

    cdef default_int i
    cdef default_int j
    cdef default_int[::1] out = np.empty((n_samples, ), dtype=int)

    rng = check_random_state(random_state)
    rng_randint = rng.randint

    # This cython implementation is based on the one of Robert Kern:
    # http://mail.scipy.org/pipermail/numpy-discussion/2010-December/
    # 054289.html
    #
    # Fill the reservoir with the first n_samples integers.
    for i in range(n_samples):
        out[i] = i

    # Replace the legacy Pyrex loop `for i from n_samples <= i < n_population:`
    # with the equivalent typed range loop: the `for ... from` syntax is
    # deprecated in Cython, and `range` over a typed integer compiles to the
    # same C loop.
    for i in range(n_samples, n_population):
        j = rng_randint(0, i + 1)
        if j < n_samples:
            out[j] = i

    return np.asarray(out)
|
||||
|
||||
|
||||
cdef _sample_without_replacement(default_int n_population,
                                 default_int n_samples,
                                 method="auto",
                                 random_state=None):
    """Sample integers without replacement.

    Private function for the implementation, see sample_without_replacement
    documentation for more details.

    Dispatches to one of the concrete strategies according to ``method``
    and, for "auto", to the sampling ratio n_samples / n_population.
    """
    _sample_without_replacement_check_input(n_population, n_samples)

    all_methods = ("auto", "tracking_selection", "reservoir_sampling", "pool")

    # Guard against division by zero; an empty population yields ratio 1.
    ratio = <double> n_samples / n_population if n_population != 0.0 else 1.0

    # Check ratio and use permutation unless ratio < 0.01 or ratio > 0.99
    if method == "auto" and ratio > 0.01 and ratio < 0.99:
        rng = check_random_state(random_state)
        return rng.permutation(n_population)[:n_samples]

    if method == "auto" or method == "tracking_selection":
        # TODO the pool based method can also be used.
        # however, it requires special benchmark to take into account
        # the memory requirement of the array vs the set.

        # The value 0.2 has been determined through benchmarking.
        if ratio < 0.2:
            return _sample_without_replacement_with_tracking_selection(
                n_population, n_samples, random_state)
        else:
            return _sample_without_replacement_with_reservoir_sampling(
                n_population, n_samples, random_state)

    elif method == "reservoir_sampling":
        return _sample_without_replacement_with_reservoir_sampling(
            n_population, n_samples, random_state)

    elif method == "pool":
        return _sample_without_replacement_with_pool(n_population, n_samples,
                                                     random_state)
    else:
        raise ValueError('Expected a method name in %s, got %s. '
                         % (all_methods, method))
|
||||
|
||||
|
||||
def sample_without_replacement(
        object n_population, object n_samples, method="auto", random_state=None):
    """Sample integers without replacement.

    Select n_samples integers from the set [0, n_population) without
    replacement.


    Parameters
    ----------
    n_population : int
        The size of the set to sample from.

    n_samples : int
        The number of integer to sample.

    random_state : int, RandomState instance or None, default=None
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    method : {"auto", "tracking_selection", "reservoir_sampling", "pool"}, \
            default='auto'
        If method == "auto", the ratio of n_samples / n_population is used
        to determine which algorithm to use:
        If ratio is between 0 and 0.01, tracking selection is used.
        If ratio is between 0.01 and 0.99, numpy.random.permutation is used.
        If ratio is greater than 0.99, reservoir sampling is used.
        The order of the selected integers is undefined. If a random order is
        desired, the selected subset should be shuffled.

        If method =="tracking_selection", a set based implementation is used
        which is suitable for `n_samples` <<< `n_population`.

        If method == "reservoir_sampling", a reservoir sampling algorithm is
        used which is suitable for high memory constraint or when
        O(`n_samples`) ~ O(`n_population`).
        The order of the selected integers is undefined. If a random order is
        desired, the selected subset should be shuffled.

        If method == "pool", a pool based algorithm is particularly fast, even
        faster than the tracking selection method. However, a vector containing
        the entire population has to be initialized.
        If n_samples ~ n_population, the reservoir sampling method is faster.

    Returns
    -------
    out : ndarray of shape (n_samples,)
        The sampled subsets of integer. The subset of selected integer might
        not be randomized, see the method argument.

    Examples
    --------
    >>> from sklearn.utils.random import sample_without_replacement
    >>> sample_without_replacement(10, 5, random_state=42)
    array([8, 1, 5, 0, 7])
    """
    cdef:
        intp_t n_pop_intp, n_samples_intp
        long n_pop_long, n_samples_long

    # On most platforms `np.int_ is np.intp`. However, before NumPy 2 the
    # default integer `np.int_` was a long which is 32bit on 64bit windows
    # while `intp` is 64bit on 64bit platforms and 32bit on 32bit ones.
    # The fused `default_int` helpers are therefore instantiated with the
    # matching C type for the running platform/NumPy combination.
    if np.int_ is np.intp:
        # Branch always taken on NumPy >=2 (or when not on 64bit windows).
        # Cython has different rules for conversion of values to integers.
        # For NumPy <1.26.2 AND Cython 3, this first branch requires `int()`
        # called explicitly to allow e.g. floats.
        n_pop_intp = int(n_population)
        n_samples_intp = int(n_samples)
        return _sample_without_replacement(
            n_pop_intp, n_samples_intp, method, random_state)
    else:
        # Branch taken on 64bit windows with Numpy<2.0 where `long` is 32bit
        n_pop_long = n_population
        n_samples_long = n_samples
        return _sample_without_replacement(
            n_pop_long, n_samples_long, method, random_state)
|
||||
|
||||
|
||||
def _our_rand_r_py(seed):
    """Python utils to test the our_rand_r function.

    Copies ``seed`` into a local uint32 so the C-level in-place state
    update does not affect the caller; returns a single draw.
    """
    cdef uint32_t my_seed = seed
    return our_rand_r(&my_seed)
|
||||
@@ -0,0 +1,2 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,152 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import itertools
|
||||
|
||||
from ... import __version__
|
||||
from ..._config import get_config
|
||||
from ..fixes import parse_version
|
||||
|
||||
|
||||
class _HTMLDocumentationLinkMixin:
    """Mixin class allowing to generate a link to the API documentation.

    This mixin relies on three attributes:
    - `_doc_link_module`: it corresponds to the root module (e.g. `sklearn`). Using this
      mixin, the default value is `sklearn`.
    - `_doc_link_template`: it corresponds to the template used to generate the
      link to the API documentation. Using this mixin, the default value is
      `"https://scikit-learn.org/{version_url}/modules/generated/
      {estimator_module}.{estimator_name}.html"`.
    - `_doc_link_url_param_generator`: it corresponds to a function that generates the
      parameters to be used in the template when the estimator module and name are not
      sufficient.

    The method :meth:`_get_doc_link` generates the link to the API documentation for a
    given estimator.

    This provides all the necessary states for
    :func:`sklearn.utils.estimator_html_repr` to generate a link to the API
    documentation for the estimator HTML diagram.

    Examples
    --------
    If the default values for `_doc_link_module`, `_doc_link_template` are not suitable,
    then you can override them and provide a method to generate the URL parameters:
    >>> from sklearn.base import BaseEstimator
    >>> doc_link_template = "https://address.local/{single_param}.html"
    >>> def url_param_generator(estimator):
    ...     return {"single_param": estimator.__class__.__name__}
    >>> class MyEstimator(BaseEstimator):
    ...     # use "builtins" since it is the associated module when declaring
    ...     # the class in a docstring
    ...     _doc_link_module = "builtins"
    ...     _doc_link_template = doc_link_template
    ...     _doc_link_url_param_generator = url_param_generator
    >>> estimator = MyEstimator()
    >>> estimator._get_doc_link()
    'https://address.local/MyEstimator.html'

    If instead of overriding the attributes inside the class definition, you want to
    override a class instance, you can use `types.MethodType` to bind the method to the
    instance:
    >>> import types
    >>> estimator = BaseEstimator()
    >>> estimator._doc_link_template = doc_link_template
    >>> estimator._doc_link_url_param_generator = types.MethodType(
    ...     url_param_generator, estimator)
    >>> estimator._get_doc_link()
    'https://address.local/BaseEstimator.html'
    """

    _doc_link_module = "sklearn"
    _doc_link_url_param_generator = None

    @property
    def _doc_link_template(self):
        # Versioned URL component: released versions point at their
        # "major.minor" docs, development versions at "dev".
        sklearn_version = parse_version(__version__)
        if sklearn_version.dev is None:
            version_url = f"{sklearn_version.major}.{sklearn_version.minor}"
        else:
            version_url = "dev"
        # An instance-level override set through the setter below takes
        # precedence over the computed default template.
        return getattr(
            self,
            "__doc_link_template",
            (
                f"https://scikit-learn.org/{version_url}/modules/generated/"
                "{estimator_module}.{estimator_name}.html"
            ),
        )

    @_doc_link_template.setter
    def _doc_link_template(self, value):
        # Stored under a literal "__doc_link_template" attribute (string
        # names passed to setattr/getattr are not name-mangled).
        setattr(self, "__doc_link_template", value)

    def _get_doc_link(self):
        """Generates a link to the API documentation for a given estimator.

        This method generates the link to the estimator's documentation page
        by using the template defined by the attribute `_doc_link_template`.

        Returns
        -------
        url : str
            The URL to the API documentation for this estimator. If the estimator does
            not belong to module `_doc_link_module`, the empty string (i.e. `""`) is
            returned.
        """
        if self.__class__.__module__.split(".")[0] != self._doc_link_module:
            return ""

        if self._doc_link_url_param_generator is None:
            estimator_name = self.__class__.__name__
            # Construct the estimator's module name, up to the first private submodule.
            # This works because in scikit-learn all public estimators are exposed at
            # that level, even if they actually live in a private sub-module.
            estimator_module = ".".join(
                itertools.takewhile(
                    lambda part: not part.startswith("_"),
                    self.__class__.__module__.split("."),
                )
            )
            return self._doc_link_template.format(
                estimator_module=estimator_module, estimator_name=estimator_name
            )
        # Custom generator supplied: it is responsible for producing every
        # parameter the template references.
        return self._doc_link_template.format(**self._doc_link_url_param_generator())
|
||||
|
||||
|
||||
class ReprHTMLMixin:
    """Mixin to handle consistently the HTML representation.

    When inheriting from this class, you need to define an attribute `_html_repr`
    which is a callable that returns the HTML representation to be shown.
    """

    @property
    def _repr_html_(self):
        """HTML representation of estimator.

        This is redundant with the logic of `_repr_mimebundle_`. The latter
        should be favored in the long term, `_repr_html_` is only
        implemented for consumers who do not interpret `_repr_mimebundle_`.
        """
        # Raising AttributeError makes hasattr(obj, "_repr_html_") False
        # when the display config is not "diagram", so notebook frontends
        # fall back to the plain-text repr.
        if get_config()["display"] != "diagram":
            raise AttributeError(
                "_repr_html_ is only defined when the "
                "'display' configuration option is set to "
                "'diagram'"
            )
        return self._repr_html_inner

    def _repr_html_inner(self):
        """This function is returned by the @property `_repr_html_` to make
        ``hasattr(estimator, "_repr_html_")`` return `True` or `False` depending
        on `get_config()["display"]`.
        """
        return self._html_repr()

    def _repr_mimebundle_(self, **kwargs):
        """Mime bundle used by jupyter kernels to display estimator"""
        # Always provide text/plain; add text/html only when diagrams are on.
        output = {"text/plain": repr(self)}
        if get_config()["display"] == "diagram":
            output["text/html"] = self._html_repr()
        return output
|
||||
@@ -0,0 +1,413 @@
|
||||
/* Stylesheet for the estimator HTML diagram. "$id" is substituted with a
   per-render container id so the rules cannot leak into the host page. */
#$id {
  /* Definition of color scheme common for light and dark mode */
  --sklearn-color-text: #000;
  --sklearn-color-text-muted: #666;
  --sklearn-color-line: gray;
  /* Definition of color scheme for unfitted estimators */
  --sklearn-color-unfitted-level-0: #fff5e6;
  --sklearn-color-unfitted-level-1: #f6e4d2;
  --sklearn-color-unfitted-level-2: #ffe0b3;
  --sklearn-color-unfitted-level-3: chocolate;
  /* Definition of color scheme for fitted estimators */
  --sklearn-color-fitted-level-0: #f0f8ff;
  --sklearn-color-fitted-level-1: #d4ebff;
  --sklearn-color-fitted-level-2: #b3dbfd;
  --sklearn-color-fitted-level-3: cornflowerblue;

  /* Specific color for light theme */
  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));
  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));
  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));
  --sklearn-color-icon: #696969;

  @media (prefers-color-scheme: dark) {
    /* Redefinition of color scheme for dark theme */
    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));
    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));
    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));
    --sklearn-color-icon: #878787;
  }
}

#$id {
  color: var(--sklearn-color-text);
}

#$id pre {
  padding: 0;
}

/* Visually-hidden checkbox driving the collapsible sections; kept
   focusable/clickable while removed from visual flow. */
#$id input.sk-hidden--visually {
  border: 0;
  clip: rect(1px 1px 1px 1px);
  clip: rect(1px, 1px, 1px, 1px);
  height: 1px;
  margin: -1px;
  overflow: hidden;
  padding: 0;
  position: absolute;
  width: 1px;
}

#$id div.sk-dashed-wrapped {
  border: 1px dashed var(--sklearn-color-line);
  margin: 0 0.4em 0.5em 0.4em;
  box-sizing: border-box;
  padding-bottom: 0.4em;
  background-color: var(--sklearn-color-background);
}

#$id div.sk-container {
  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`
     but bootstrap.min.css set `[hidden] { display: none !important; }`
     so we also need the `!important` here to be able to override the
     default hidden behavior on the sphinx rendered scikit-learn.org.
     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */
  display: inline-block !important;
  position: relative;
}

#$id div.sk-text-repr-fallback {
  display: none;
}
||||
|
||||
div.sk-parallel-item,
|
||||
div.sk-serial,
|
||||
div.sk-item {
|
||||
/* draw centered vertical line to link estimators */
|
||||
background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));
|
||||
background-size: 2px 100%;
|
||||
background-repeat: no-repeat;
|
||||
background-position: center center;
|
||||
}
|
||||
|
||||
/* Parallel-specific style estimator block */
|
||||
|
||||
#$id div.sk-parallel-item::after {
|
||||
content: "";
|
||||
width: 100%;
|
||||
border-bottom: 2px solid var(--sklearn-color-text-on-default-background);
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
#$id div.sk-parallel {
|
||||
display: flex;
|
||||
align-items: stretch;
|
||||
justify-content: center;
|
||||
background-color: var(--sklearn-color-background);
|
||||
position: relative;
|
||||
}
|
||||
|
||||
#$id div.sk-parallel-item {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
#$id div.sk-parallel-item:first-child::after {
|
||||
align-self: flex-end;
|
||||
width: 50%;
|
||||
}
|
||||
|
||||
#$id div.sk-parallel-item:last-child::after {
|
||||
align-self: flex-start;
|
||||
width: 50%;
|
||||
}
|
||||
|
||||
#$id div.sk-parallel-item:only-child::after {
|
||||
width: 0;
|
||||
}
|
||||
|
||||
/* Serial-specific style estimator block */
|
||||
|
||||
#$id div.sk-serial {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
background-color: var(--sklearn-color-background);
|
||||
padding-right: 1em;
|
||||
padding-left: 1em;
|
||||
}
|
||||
|
||||
|
||||
/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is
|
||||
clickable and can be expanded/collapsed.
|
||||
- Pipeline and ColumnTransformer use this feature and define the default style
|
||||
- Estimators will overwrite some part of the style using the `sk-estimator` class
|
||||
*/
|
||||
|
||||
/* Pipeline and ColumnTransformer style (default) */
|
||||
|
||||
#$id div.sk-toggleable {
|
||||
/* Default theme specific background. It is overwritten whether we have a
|
||||
specific estimator or a Pipeline/ColumnTransformer */
|
||||
background-color: var(--sklearn-color-background);
|
||||
}
|
||||
|
||||
/* Toggleable label */
|
||||
#$id label.sk-toggleable__label {
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
width: 100%;
|
||||
margin-bottom: 0;
|
||||
padding: 0.5em;
|
||||
box-sizing: border-box;
|
||||
text-align: center;
|
||||
align-items: start;
|
||||
justify-content: space-between;
|
||||
gap: 0.5em;
|
||||
}
|
||||
|
||||
#$id label.sk-toggleable__label .caption {
|
||||
font-size: 0.6rem;
|
||||
font-weight: lighter;
|
||||
color: var(--sklearn-color-text-muted);
|
||||
}
|
||||
|
||||
#$id label.sk-toggleable__label-arrow:before {
|
||||
/* Arrow on the left of the label */
|
||||
content: "▸";
|
||||
float: left;
|
||||
margin-right: 0.25em;
|
||||
color: var(--sklearn-color-icon);
|
||||
}
|
||||
|
||||
#$id label.sk-toggleable__label-arrow:hover:before {
|
||||
color: var(--sklearn-color-text);
|
||||
}
|
||||
|
||||
/* Toggleable content - dropdown */
|
||||
|
||||
#$id div.sk-toggleable__content {
|
||||
display: none;
|
||||
text-align: left;
|
||||
/* unfitted */
|
||||
background-color: var(--sklearn-color-unfitted-level-0);
|
||||
}
|
||||
|
||||
#$id div.sk-toggleable__content.fitted {
|
||||
/* fitted */
|
||||
background-color: var(--sklearn-color-fitted-level-0);
|
||||
}
|
||||
|
||||
#$id div.sk-toggleable__content pre {
|
||||
margin: 0.2em;
|
||||
border-radius: 0.25em;
|
||||
color: var(--sklearn-color-text);
|
||||
/* unfitted */
|
||||
background-color: var(--sklearn-color-unfitted-level-0);
|
||||
}
|
||||
|
||||
#$id div.sk-toggleable__content.fitted pre {
|
||||
/* unfitted */
|
||||
background-color: var(--sklearn-color-fitted-level-0);
|
||||
}
|
||||
|
||||
#$id input.sk-toggleable__control:checked~div.sk-toggleable__content {
|
||||
/* Expand drop-down */
|
||||
display: block;
|
||||
width: 100%;
|
||||
overflow: visible;
|
||||
}
|
||||
|
||||
#$id input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {
|
||||
content: "▾";
|
||||
}
|
||||
|
||||
/* Pipeline/ColumnTransformer-specific style */
|
||||
|
||||
#$id div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {
|
||||
color: var(--sklearn-color-text);
|
||||
background-color: var(--sklearn-color-unfitted-level-2);
|
||||
}
|
||||
|
||||
#$id div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {
|
||||
background-color: var(--sklearn-color-fitted-level-2);
|
||||
}
|
||||
|
||||
/* Estimator-specific style */
|
||||
|
||||
/* Colorize estimator box */
|
||||
#$id div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {
|
||||
/* unfitted */
|
||||
background-color: var(--sklearn-color-unfitted-level-2);
|
||||
}
|
||||
|
||||
#$id div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {
|
||||
/* fitted */
|
||||
background-color: var(--sklearn-color-fitted-level-2);
|
||||
}
|
||||
|
||||
#$id div.sk-label label.sk-toggleable__label,
|
||||
#$id div.sk-label label {
|
||||
/* The background is the default theme color */
|
||||
color: var(--sklearn-color-text-on-default-background);
|
||||
}
|
||||
|
||||
/* On hover, darken the color of the background */
|
||||
#$id div.sk-label:hover label.sk-toggleable__label {
|
||||
color: var(--sklearn-color-text);
|
||||
background-color: var(--sklearn-color-unfitted-level-2);
|
||||
}
|
||||
|
||||
/* Label box, darken color on hover, fitted */
|
||||
#$id div.sk-label.fitted:hover label.sk-toggleable__label.fitted {
|
||||
color: var(--sklearn-color-text);
|
||||
background-color: var(--sklearn-color-fitted-level-2);
|
||||
}
|
||||
|
||||
/* Estimator label */
|
||||
|
||||
#$id div.sk-label label {
|
||||
font-family: monospace;
|
||||
font-weight: bold;
|
||||
display: inline-block;
|
||||
line-height: 1.2em;
|
||||
}
|
||||
|
||||
#$id div.sk-label-container {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
/* Estimator-specific */
|
||||
#$id div.sk-estimator {
|
||||
font-family: monospace;
|
||||
border: 1px dotted var(--sklearn-color-border-box);
|
||||
border-radius: 0.25em;
|
||||
box-sizing: border-box;
|
||||
margin-bottom: 0.5em;
|
||||
/* unfitted */
|
||||
background-color: var(--sklearn-color-unfitted-level-0);
|
||||
}
|
||||
|
||||
#$id div.sk-estimator.fitted {
|
||||
/* fitted */
|
||||
background-color: var(--sklearn-color-fitted-level-0);
|
||||
}
|
||||
|
||||
/* on hover */
|
||||
#$id div.sk-estimator:hover {
|
||||
/* unfitted */
|
||||
background-color: var(--sklearn-color-unfitted-level-2);
|
||||
}
|
||||
|
||||
#$id div.sk-estimator.fitted:hover {
|
||||
/* fitted */
|
||||
background-color: var(--sklearn-color-fitted-level-2);
|
||||
}
|
||||
|
||||
/* Specification for estimator info (e.g. "i" and "?") */
|
||||
|
||||
/* Common style for "i" and "?" */
|
||||
|
||||
.sk-estimator-doc-link,
|
||||
a:link.sk-estimator-doc-link,
|
||||
a:visited.sk-estimator-doc-link {
|
||||
float: right;
|
||||
font-size: smaller;
|
||||
line-height: 1em;
|
||||
font-family: monospace;
|
||||
background-color: var(--sklearn-color-background);
|
||||
border-radius: 1em;
|
||||
height: 1em;
|
||||
width: 1em;
|
||||
text-decoration: none !important;
|
||||
margin-left: 0.5em;
|
||||
text-align: center;
|
||||
/* unfitted */
|
||||
border: var(--sklearn-color-unfitted-level-1) 1pt solid;
|
||||
color: var(--sklearn-color-unfitted-level-1);
|
||||
}
|
||||
|
||||
.sk-estimator-doc-link.fitted,
|
||||
a:link.sk-estimator-doc-link.fitted,
|
||||
a:visited.sk-estimator-doc-link.fitted {
|
||||
/* fitted */
|
||||
border: var(--sklearn-color-fitted-level-1) 1pt solid;
|
||||
color: var(--sklearn-color-fitted-level-1);
|
||||
}
|
||||
|
||||
/* On hover */
|
||||
div.sk-estimator:hover .sk-estimator-doc-link:hover,
|
||||
.sk-estimator-doc-link:hover,
|
||||
div.sk-label-container:hover .sk-estimator-doc-link:hover,
|
||||
.sk-estimator-doc-link:hover {
|
||||
/* unfitted */
|
||||
background-color: var(--sklearn-color-unfitted-level-3);
|
||||
color: var(--sklearn-color-background);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,
|
||||
.sk-estimator-doc-link.fitted:hover,
|
||||
div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,
|
||||
.sk-estimator-doc-link.fitted:hover {
|
||||
/* fitted */
|
||||
background-color: var(--sklearn-color-fitted-level-3);
|
||||
color: var(--sklearn-color-background);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
/* Span, style for the box shown on hovering the info icon */
|
||||
.sk-estimator-doc-link span {
|
||||
display: none;
|
||||
z-index: 9999;
|
||||
position: relative;
|
||||
font-weight: normal;
|
||||
right: .2ex;
|
||||
padding: .5ex;
|
||||
margin: .5ex;
|
||||
width: min-content;
|
||||
min-width: 20ex;
|
||||
max-width: 50ex;
|
||||
color: var(--sklearn-color-text);
|
||||
box-shadow: 2pt 2pt 4pt #999;
|
||||
/* unfitted */
|
||||
background: var(--sklearn-color-unfitted-level-0);
|
||||
border: .5pt solid var(--sklearn-color-unfitted-level-3);
|
||||
}
|
||||
|
||||
.sk-estimator-doc-link.fitted span {
|
||||
/* fitted */
|
||||
background: var(--sklearn-color-fitted-level-0);
|
||||
border: var(--sklearn-color-fitted-level-3);
|
||||
}
|
||||
|
||||
.sk-estimator-doc-link:hover span {
|
||||
display: block;
|
||||
}
|
||||
|
||||
/* "?"-specific style due to the `<a>` HTML tag */
|
||||
|
||||
#$id a.estimator_doc_link {
|
||||
float: right;
|
||||
font-size: 1rem;
|
||||
line-height: 1em;
|
||||
font-family: monospace;
|
||||
background-color: var(--sklearn-color-background);
|
||||
border-radius: 1rem;
|
||||
height: 1rem;
|
||||
width: 1rem;
|
||||
text-decoration: none;
|
||||
/* unfitted */
|
||||
color: var(--sklearn-color-unfitted-level-1);
|
||||
border: var(--sklearn-color-unfitted-level-1) 1pt solid;
|
||||
}
|
||||
|
||||
#$id a.estimator_doc_link.fitted {
|
||||
/* fitted */
|
||||
border: var(--sklearn-color-fitted-level-1) 1pt solid;
|
||||
color: var(--sklearn-color-fitted-level-1);
|
||||
}
|
||||
|
||||
/* On hover */
|
||||
#$id a.estimator_doc_link:hover {
|
||||
/* unfitted */
|
||||
background-color: var(--sklearn-color-unfitted-level-3);
|
||||
color: var(--sklearn-color-background);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
#$id a.estimator_doc_link.fitted:hover {
|
||||
/* fitted */
|
||||
background-color: var(--sklearn-color-fitted-level-3);
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
function copyToClipboard(text, element) {
    // Build the fully-qualified parameter name: the prefix lives in the
    // closest toggleable content's data-param-prefix attribute.
    const container = element.closest('.sk-toggleable__content');
    const prefix = container ? container.dataset.paramPrefix : '';
    const fullParamName = prefix ? `${prefix}${text}` : text;

    // Snapshot the presentation bits we temporarily override below.
    const originalStyle = element.style;
    const originalWidth = window.getComputedStyle(element).width;
    const originalHTML = element.innerHTML.replace('Copied!', '');

    // Shared cleanup: put the icon back after the feedback has been shown.
    const restoreLater = () => {
        setTimeout(() => {
            element.innerHTML = originalHTML;
            element.style = originalStyle;
        }, 2000);
    };

    navigator.clipboard.writeText(fullParamName)
        .then(() => {
            // Freeze the width so the feedback text does not reflow the icon.
            element.style.width = originalWidth;
            element.style.color = 'green';
            element.innerHTML = "Copied!";
            restoreLater();
        })
        .catch(err => {
            console.error('Failed to copy:', err);
            element.style.color = 'red';
            element.innerHTML = "Failed!";
            restoreLater();
        });
    return false;
}
|
||||
|
||||
// Give every copy icon a tooltip showing the fully-qualified parameter name.
document.querySelectorAll('.fa-regular.fa-copy').forEach(function(element) {
    const container = element.closest('.sk-toggleable__content');
    const prefix = container ? container.dataset.paramPrefix : '';
    // The parameter name is the text of the sibling following the icon's parent.
    const paramName = element.parentElement.nextElementSibling.textContent.trim();
    const fullParamName = prefix ? `${prefix}${paramName}` : paramName;

    element.setAttribute('title', fullParamName);
});
|
||||
@@ -0,0 +1,497 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import html
|
||||
from contextlib import closing
|
||||
from inspect import isclass
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
|
||||
from ... import config_context
|
||||
|
||||
|
||||
class _IDCounter:
|
||||
"""Generate sequential ids with a prefix."""
|
||||
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
self.count = 0
|
||||
|
||||
def get_id(self):
|
||||
self.count += 1
|
||||
return f"{self.prefix}-{self.count}"
|
||||
|
||||
|
||||
def _get_css_style():
    """Load and concatenate the CSS files shipped next to this module."""
    here = Path(__file__).parent
    # estimator.css styles the diagram boxes; params.css styles the
    # parameter tables inside them.
    parts = [
        (here / name).read_text(encoding="utf-8")
        for name in ("estimator.css", "params.css")
    ]
    return "\n".join(parts)
|
||||
|
||||
|
||||
# Module-wide counters so every rendered container/estimator gets a unique
# DOM id, plus the combined CSS loaded once at import time.
_CONTAINER_ID_COUNTER = _IDCounter("sk-container-id")
_ESTIMATOR_ID_COUNTER = _IDCounter("sk-estimator-id")
_CSS_STYLE = _get_css_style()
|
||||
|
||||
|
||||
class _VisualBlock:
|
||||
"""HTML Representation of Estimator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kind : {'serial', 'parallel', 'single'}
|
||||
kind of HTML block
|
||||
|
||||
estimators : list of estimators or `_VisualBlock`s or a single estimator
|
||||
If kind != 'single', then `estimators` is a list of
|
||||
estimators.
|
||||
If kind == 'single', then `estimators` is a single estimator.
|
||||
|
||||
names : list of str, default=None
|
||||
If kind != 'single', then `names` corresponds to estimators.
|
||||
If kind == 'single', then `names` is a single string corresponding to
|
||||
the single estimator.
|
||||
|
||||
name_details : list of str, str, or None, default=None
|
||||
If kind != 'single', then `name_details` corresponds to `names`.
|
||||
If kind == 'single', then `name_details` is a single string
|
||||
corresponding to the single estimator.
|
||||
|
||||
name_caption : str, default=None
|
||||
The caption below the name. `None` stands for no caption.
|
||||
Only active when kind == 'single'.
|
||||
|
||||
doc_link_label : str, default=None
|
||||
The label for the documentation link. If provided, the label would be
|
||||
"Documentation for {doc_link_label}". Otherwise it will look for `names`.
|
||||
Only active when kind == 'single'.
|
||||
|
||||
dash_wrapped : bool, default=True
|
||||
If true, wrapped HTML element will be wrapped with a dashed border.
|
||||
Only active when kind != 'single'.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kind,
|
||||
estimators,
|
||||
*,
|
||||
names=None,
|
||||
name_details=None,
|
||||
name_caption=None,
|
||||
doc_link_label=None,
|
||||
dash_wrapped=True,
|
||||
):
|
||||
self.kind = kind
|
||||
self.estimators = estimators
|
||||
self.dash_wrapped = dash_wrapped
|
||||
self.name_caption = name_caption
|
||||
self.doc_link_label = doc_link_label
|
||||
|
||||
if self.kind in ("parallel", "serial"):
|
||||
if names is None:
|
||||
names = (None,) * len(estimators)
|
||||
if name_details is None:
|
||||
name_details = (None,) * len(estimators)
|
||||
|
||||
self.names = names
|
||||
self.name_details = name_details
|
||||
|
||||
def _sk_visual_block_(self):
|
||||
return self
|
||||
|
||||
|
||||
def _write_label_html(
|
||||
out,
|
||||
params,
|
||||
name,
|
||||
name_details,
|
||||
name_caption=None,
|
||||
doc_link_label=None,
|
||||
outer_class="sk-label-container",
|
||||
inner_class="sk-label",
|
||||
checked=False,
|
||||
doc_link="",
|
||||
is_fitted_css_class="",
|
||||
is_fitted_icon="",
|
||||
param_prefix="",
|
||||
):
|
||||
"""Write labeled html with or without a dropdown with named details.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
out : file-like object
|
||||
The file to write the HTML representation to.
|
||||
params: str
|
||||
If estimator has `get_params` method, this is the HTML representation
|
||||
of the estimator's parameters and their values. When the estimator
|
||||
does not have `get_params`, it is an empty string.
|
||||
name : str
|
||||
The label for the estimator. It corresponds either to the estimator class name
|
||||
for a simple estimator or in the case of a `Pipeline` and `ColumnTransformer`,
|
||||
it corresponds to the name of the step.
|
||||
name_details : str
|
||||
The details to show as content in the dropdown part of the toggleable label. It
|
||||
can contain information such as non-default parameters or column information for
|
||||
`ColumnTransformer`.
|
||||
name_caption : str, default=None
|
||||
The caption below the name. If `None`, no caption will be created.
|
||||
doc_link_label : str, default=None
|
||||
The label for the documentation link. If provided, the label would be
|
||||
"Documentation for {doc_link_label}". Otherwise it will look for `name`.
|
||||
outer_class : {"sk-label-container", "sk-item"}, default="sk-label-container"
|
||||
The CSS class for the outer container.
|
||||
inner_class : {"sk-label", "sk-estimator"}, default="sk-label"
|
||||
The CSS class for the inner container.
|
||||
checked : bool, default=False
|
||||
Whether the dropdown is folded or not. With a single estimator, we intend to
|
||||
unfold the content.
|
||||
doc_link : str, default=""
|
||||
The link to the documentation for the estimator. If an empty string, no link is
|
||||
added to the diagram. This can be generated for an estimator if it uses the
|
||||
`_HTMLDocumentationLinkMixin`.
|
||||
is_fitted_css_class : {"", "fitted"}
|
||||
The CSS class to indicate whether or not the estimator is fitted. The
|
||||
empty string means that the estimator is not fitted and "fitted" means that the
|
||||
estimator is fitted.
|
||||
is_fitted_icon : str, default=""
|
||||
The HTML representation to show the fitted information in the diagram. An empty
|
||||
string means that no information is shown.
|
||||
param_prefix : str, default=""
|
||||
The prefix to prepend to parameter names for nested estimators.
|
||||
"""
|
||||
out.write(
|
||||
f'<div class="{outer_class}"><div'
|
||||
f' class="{inner_class} {is_fitted_css_class} sk-toggleable">'
|
||||
)
|
||||
name = html.escape(name)
|
||||
if name_details is not None:
|
||||
name_details = html.escape(str(name_details))
|
||||
checked_str = "checked" if checked else ""
|
||||
est_id = _ESTIMATOR_ID_COUNTER.get_id()
|
||||
|
||||
if doc_link:
|
||||
doc_label = "<span>Online documentation</span>"
|
||||
if doc_link_label is not None:
|
||||
doc_label = f"<span>Documentation for {doc_link_label}</span>"
|
||||
elif name is not None:
|
||||
doc_label = f"<span>Documentation for {name}</span>"
|
||||
doc_link = (
|
||||
f'<a class="sk-estimator-doc-link {is_fitted_css_class}"'
|
||||
f' rel="noreferrer" target="_blank" href="{doc_link}">?{doc_label}</a>'
|
||||
)
|
||||
|
||||
name_caption_div = (
|
||||
""
|
||||
if name_caption is None
|
||||
else f'<div class="caption">{html.escape(name_caption)}</div>'
|
||||
)
|
||||
name_caption_div = f"<div><div>{name}</div>{name_caption_div}</div>"
|
||||
links_div = (
|
||||
f"<div>{doc_link}{is_fitted_icon}</div>"
|
||||
if doc_link or is_fitted_icon
|
||||
else ""
|
||||
)
|
||||
|
||||
label_html = (
|
||||
f'<label for="{est_id}" class="sk-toggleable__label {is_fitted_css_class} '
|
||||
f'sk-toggleable__label-arrow">{name_caption_div}{links_div}</label>'
|
||||
)
|
||||
|
||||
fmt_str = (
|
||||
f'<input class="sk-toggleable__control sk-hidden--visually" id="{est_id}" '
|
||||
f'type="checkbox" {checked_str}>{label_html}<div '
|
||||
f'class="sk-toggleable__content {is_fitted_css_class}" '
|
||||
f'data-param-prefix="{html.escape(param_prefix)}">'
|
||||
)
|
||||
|
||||
if params:
|
||||
fmt_str = "".join([fmt_str, f"{params}</div>"])
|
||||
elif name_details and ("Pipeline" not in name):
|
||||
fmt_str = "".join([fmt_str, f"<pre>{name_details}</pre></div>"])
|
||||
|
||||
out.write(fmt_str)
|
||||
else:
|
||||
out.write(f"<label>{name}</label>")
|
||||
out.write("</div></div>") # outer_class inner_class
|
||||
|
||||
|
||||
def _get_visual_block(estimator):
    """Return a `_VisualBlock` describing how to display `estimator`.

    Preference order: the estimator's own `_sk_visual_block_` hook, special
    cases for strings and `None`, a "parallel" block when the estimator
    wraps sub-estimators, and finally a plain "single" block.
    """
    if hasattr(estimator, "_sk_visual_block_"):
        try:
            return estimator._sk_visual_block_()
        except Exception:
            # A broken hook must not break the repr; fall back to a plain
            # single-estimator block.
            return _VisualBlock(
                "single",
                estimator,
                names=estimator.__class__.__name__,
                name_details=str(estimator),
            )

    if isinstance(estimator, str):
        return _VisualBlock(
            "single", estimator, names=estimator, name_details=estimator
        )
    if estimator is None:
        return _VisualBlock("single", estimator, names="None", name_details="None")

    # A meta-estimator is detected by having sub-estimators (objects with
    # both `get_params` and `fit`) among its own shallow parameters.
    if hasattr(estimator, "get_params") and not isclass(estimator):
        wrapped = [
            (key, est)
            for key, est in estimator.get_params(deep=False).items()
            if hasattr(est, "get_params") and hasattr(est, "fit") and not isclass(est)
        ]
        if wrapped:
            return _VisualBlock(
                "parallel",
                [est for _, est in wrapped],
                names=[f"{key}: {est.__class__.__name__}" for key, est in wrapped],
                name_details=[str(est) for _, est in wrapped],
            )

    return _VisualBlock(
        "single",
        estimator,
        names=estimator.__class__.__name__,
        name_details=str(estimator),
    )
|
||||
|
||||
|
||||
def _write_estimator_html(
    out,
    estimator,
    estimator_label,
    estimator_label_details,
    is_fitted_css_class,
    is_fitted_icon="",
    first_call=False,
    param_prefix="",
):
    """Recursively write an estimator's HTML (serial, parallel, or single).

    Parameters
    ----------
    out : file-like object
        Stream receiving the HTML.
    estimator : estimator object
        The estimator to visualize; may also be a `_VisualBlock`.
    estimator_label : str
        Label for the estimator: its class name, or the step name inside a
        `Pipeline`/`ColumnTransformer`.
    estimator_label_details : str
        Dropdown content for the label, e.g. non-default parameters or
        column information for `ColumnTransformer`.
    is_fitted_css_class : {"", "fitted"}
        CSS class flagging whether the estimator is fitted.
    is_fitted_icon : str, default=""
        HTML of the fitted-status icon; forced to "" for nested calls so
        the icon only appears on the outermost estimator.
    first_call : bool, default=False
        True only for the outermost call.
    param_prefix : str, default=""
        Prefix prepended to nested parameter names, e.g.
        "pipeline__stepname__".
    """
    if first_call:
        est_block = _get_visual_block(estimator)
    else:
        is_fitted_icon = ""
        # Nested reprs only show changed parameters to keep the diagram short.
        with config_context(print_changed_only=True):
            est_block = _get_visual_block(estimator)
    # `estimator` can also be an instance of `_VisualBlock`
    if hasattr(estimator, "_get_doc_link"):
        doc_link = estimator._get_doc_link()
    else:
        doc_link = ""

    if est_block.kind in ("serial", "parallel"):
        dashed = first_call or est_block.dash_wrapped
        dash_cls = " sk-dashed-wrapped" if dashed else ""
        out.write(f'<div class="sk-item{dash_cls}">')

        if estimator_label:
            if hasattr(estimator, "get_params") and hasattr(
                estimator, "_get_params_html"
            ):
                params = estimator._get_params_html(deep=False)._repr_html_inner()
            else:
                params = ""

            _write_label_html(
                out,
                params,
                estimator_label,
                estimator_label_details,
                doc_link=doc_link,
                is_fitted_css_class=is_fitted_css_class,
                is_fitted_icon=is_fitted_icon,
                param_prefix=param_prefix,
            )

        kind = est_block.kind
        out.write(f'<div class="sk-{kind}">')

        for est, name, name_details in zip(
            est_block.estimators, est_block.names, est_block.name_details
        ):
            # Extend the parameter prefix for this child. Step names look
            # like "stepname: ClassName"; only the part before ":" is the
            # parameter name.
            if param_prefix and hasattr(name, "split"):
                new_prefix = f"{param_prefix}{name.split(':')[0]}__"
            elif hasattr(name, "split"):
                new_prefix = f"{name.split(':')[0]}__" if name else ""
            else:
                new_prefix = param_prefix

            if kind == "serial":
                _write_estimator_html(
                    out,
                    est,
                    name,
                    name_details,
                    is_fitted_css_class=is_fitted_css_class,
                    param_prefix=new_prefix,
                )
            else:  # parallel
                out.write('<div class="sk-parallel-item">')
                # Wrap each parallel child in its own serial block.
                serial_block = _VisualBlock("serial", [est], dash_wrapped=False)
                _write_estimator_html(
                    out,
                    serial_block,
                    name,
                    name_details,
                    is_fitted_css_class=is_fitted_css_class,
                    param_prefix=new_prefix,
                )
                out.write("</div>")  # sk-parallel-item

        out.write("</div></div>")
    elif est_block.kind == "single":
        if hasattr(estimator, "_get_params_html"):
            params = estimator._get_params_html()._repr_html_inner()
        else:
            params = ""

        _write_label_html(
            out,
            params,
            est_block.names,
            est_block.name_details,
            est_block.name_caption,
            est_block.doc_link_label,
            outer_class="sk-item",
            inner_class="sk-estimator",
            checked=first_call,
            doc_link=doc_link,
            is_fitted_css_class=is_fitted_css_class,
            is_fitted_icon=is_fitted_icon,
            param_prefix=param_prefix,
        )
|
||||
|
||||
|
||||
def estimator_html_repr(estimator):
    """Build a HTML representation of an estimator.

    Read more in the :ref:`User Guide <visualizing_composite_estimators>`.

    Parameters
    ----------
    estimator : estimator object
        The estimator to visualize.

    Returns
    -------
    html: str
        HTML representation of estimator.

    Examples
    --------
    >>> from sklearn.utils._repr_html.estimator import estimator_html_repr
    >>> from sklearn.linear_model import LogisticRegression
    >>> estimator_html_repr(LogisticRegression())
    '<style>#sk-container-id...'
    """
    # Imported lazily to avoid a circular import at module load time.
    from sklearn.exceptions import NotFittedError
    from sklearn.utils.validation import check_is_fitted

    # Determine the fitted-status badge shown next to the estimator name.
    if not hasattr(estimator, "fit"):
        status_label = "<span>Not fitted</span>"
        is_fitted_css_class = ""
    else:
        try:
            check_is_fitted(estimator)
            status_label = "<span>Fitted</span>"
            is_fitted_css_class = "fitted"
        except NotFittedError:
            status_label = "<span>Not fitted</span>"
            is_fitted_css_class = ""

    is_fitted_icon = (
        f'<span class="sk-estimator-doc-link {is_fitted_css_class}">'
        f"i{status_label}</span>"
    )
    with closing(StringIO()) as out:
        container_id = _CONTAINER_ID_COUNTER.get_id()
        style_with_id = Template(_CSS_STYLE).substitute(id=container_id)
        estimator_str = str(estimator)

        # The fallback message is shown by default and loading the CSS sets
        # div.sk-text-repr-fallback to display: none to hide it. In a
        # trusted notebook the CSS is loaded and the fallback disappears;
        # in an untrusted notebook the CSS is not loaded and the fallback
        # shows. The reverse logic applies to div.sk-container, which is
        # hidden by default and displayed by the CSS.
        fallback_msg = (
            "In a Jupyter environment, please rerun this cell to show the HTML"
            " representation or trust the notebook. <br />On GitHub, the"
            " HTML representation is unable to render, please try loading this page"
            " with nbviewer.org."
        )
        out.write(
            f"<style>{style_with_id}</style>"
            f"<body>"
            f'<div id="{container_id}" class="sk-top-container">'
            '<div class="sk-text-repr-fallback">'
            f"<pre>{html.escape(estimator_str)}</pre><b>{fallback_msg}</b>"
            "</div>"
            '<div class="sk-container" hidden>'
        )
        _write_estimator_html(
            out,
            estimator,
            estimator.__class__.__name__,
            estimator_str,
            first_call=True,
            is_fitted_css_class=is_fitted_css_class,
            is_fitted_icon=is_fitted_icon,
        )
        # Inline the click/copy behavior so the output is self-contained.
        with open(str(Path(__file__).parent / "estimator.js"), "r") as f:
            script = f.read()

        out.write(f"</div></div><script>{script}</script></body>")

        return out.getvalue()
|
||||
@@ -0,0 +1,63 @@
|
||||
/* Styles for the collapsible parameter table embedded in the estimator
   HTML representation. Row classes "user-set" / "default" are emitted by
   the Python renderer to distinguish user-provided values from defaults. */

/* Clickable "Parameters" header of the collapsible <details> element. */
.estimator-table summary {
    padding: .5rem;
    font-family: monospace;
    cursor: pointer;
}

/* Small inner padding once the details element is expanded. */
.estimator-table details[open] {
    padding-left: 0.1rem;
    padding-right: 0.1rem;
    padding-bottom: 0.3rem;
}

/* Center the table horizontally inside the details element. */
.estimator-table .parameters-table {
    margin-left: auto !important;
    margin-right: auto !important;
}

/* Zebra striping: alternate white and light-grey rows. */
.estimator-table .parameters-table tr:nth-child(odd) {
    background-color: #fff;
}

.estimator-table .parameters-table tr:nth-child(even) {
    background-color: #f6f6f6;
}

/* Highlight the hovered row. */
.estimator-table .parameters-table tr:hover {
    background-color: #e0e0e0;
}

.estimator-table table td {
    border: 1px solid rgba(106, 105, 104, 0.232);
}

/* Parameters explicitly set by the user are rendered in orange. */
.user-set td {
    color:rgb(255, 94, 0);
    text-align: left;
}

/* Keep the orange color inside <pre> value cells and drop any inherited
   background so the zebra striping shows through. */
.user-set td.value pre {
    color:rgb(255, 94, 0) !important;
    background-color: transparent !important;
}

/* Parameters left at their default value are rendered in black. */
.default td {
    color: black;
    text-align: left;
}

/* The copy icon keeps a neutral color in both row kinds. */
.user-set td i,
.default td i {
    color: black;
}

/* Clipboard icon (inlined Font Awesome SVG) shown next to each parameter;
   clicking it copies the parameter name (see copyToClipboard in estimator.js). */
.copy-paste-icon {
    background-image: url(data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCA0NDggNTEyIj48IS0tIUZvbnQgQXdlc29tZSBGcmVlIDYuNy4yIGJ5IEBmb250YXdlc29tZSAtIGh0dHBzOi8vZm9udGF3ZXNvbWUuY29tIExpY2Vuc2UgLSBodHRwczovL2ZvbnRhd2Vzb21lLmNvbS9saWNlbnNlL2ZyZWUgQ29weXJpZ2h0IDIwMjUgRm9udGljb25zLCBJbmMuLS0+PHBhdGggZD0iTTIwOCAwTDMzMi4xIDBjMTIuNyAwIDI0LjkgNS4xIDMzLjkgMTQuMWw2Ny45IDY3LjljOSA5IDE0LjEgMjEuMiAxNC4xIDMzLjlMNDQ4IDMzNmMwIDI2LjUtMjEuNSA0OC00OCA0OGwtMTkyIDBjLTI2LjUgMC00OC0yMS41LTQ4LTQ4bDAtMjg4YzAtMjYuNSAyMS41LTQ4IDQ4LTQ4ek00OCAxMjhsODAgMCAwIDY0LTY0IDAgMCAyNTYgMTkyIDAgMC0zMiA2NCAwIDAgNDhjMCAyNi41LTIxLjUgNDgtNDggNDhMNDggNTEyYy0yNi41IDAtNDgtMjEuNS00OC00OEwwIDE3NmMwLTI2LjUgMjEuNS00OCA0OC00OHoiLz48L3N2Zz4=);
    background-repeat: no-repeat;
    background-size: 14px 14px;
    background-position: 0;
    display: inline-block;
    width: 14px;
    height: 14px;
    cursor: pointer;
}
|
||||
@@ -0,0 +1,83 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import html
|
||||
import reprlib
|
||||
from collections import UserDict
|
||||
|
||||
from sklearn.utils._repr_html.base import ReprHTMLMixin
|
||||
|
||||
|
||||
def _read_params(name, value, non_default_params):
|
||||
"""Categorizes parameters as 'default' or 'user-set' and formats their values.
|
||||
Escapes or truncates parameter values for display safety and readability.
|
||||
"""
|
||||
r = reprlib.Repr()
|
||||
r.maxlist = 2 # Show only first 2 items of lists
|
||||
r.maxtuple = 1 # Show only first item of tuples
|
||||
r.maxstring = 50 # Limit string length
|
||||
cleaned_value = html.escape(r.repr(value))
|
||||
|
||||
param_type = "user-set" if name in non_default_params else "default"
|
||||
|
||||
return {"param_type": param_type, "param_name": name, "param_value": cleaned_value}
|
||||
|
||||
|
||||
def _params_html_repr(params):
    """Generate HTML representation of estimator parameters.

    Creates an HTML table with parameter names and values, wrapped in a
    collapsible details element. Parameters are styled differently based
    on whether they are default or user-set values.

    Parameters
    ----------
    params : ParamsDict
        Mapping of parameter names to values; must expose a ``non_default``
        attribute listing the names the user set explicitly.

    Returns
    -------
    str
        HTML fragment: a ``<div class="estimator-table">`` containing one
        table row per parameter.
    """
    # Outer collapsible wrapper; {rows} is filled with the formatted rows.
    HTML_TEMPLATE = """
    <div class="estimator-table">
        <details>
            <summary>Parameters</summary>
            <table class="parameters-table">
              <tbody>
                {rows}
              </tbody>
            </table>
        </details>
    </div>
    """
    # One row per parameter. The row's CSS class ("default"/"user-set") drives
    # the styling; the <i> icon copies the parameter name to the clipboard via
    # copyToClipboard defined in estimator.js.
    # NOTE(review): upstream the character after {param_name} is a
    # non-breaking space (&nbsp;) — confirm it survived extraction.
    ROW_TEMPLATE = """
    <tr class="{param_type}">
        <td><i class="copy-paste-icon"
             onclick="copyToClipboard('{param_name}',
                      this.parentElement.nextElementSibling)"
        ></i></td>
        <td class="param">{param_name} </td>
        <td class="value">{param_value}</td>
    </tr>
    """

    # _read_params escapes/abbreviates each value and decides the row class.
    rows = [
        ROW_TEMPLATE.format(**_read_params(name, value, params.non_default))
        for name, value in params.items()
    ]

    return HTML_TEMPLATE.format(rows="\n".join(rows))
|
||||
|
||||
|
||||
class ParamsDict(ReprHTMLMixin, UserDict):
    """Dictionary-like class to store and provide an HTML representation.

    It builds an HTML structure to be used with Jupyter notebooks or similar
    environments. It allows storing metadata to track non-default parameters.

    Parameters
    ----------
    params : dict, default=None
        The original dictionary of parameters and their values.

    non_default : tuple
        The list of non-default parameters.
    """

    # Hook consumed by ReprHTMLMixin: called with the instance to produce the
    # "text/html" payload of the rich repr.
    _html_repr = _params_html_repr

    def __init__(self, params=None, non_default=tuple()):
        # `params or {}` maps None (and any falsy mapping) to an empty dict so
        # `ParamsDict()` is valid; UserDict copies the mapping into self.data.
        super().__init__(params or {})
        # Names of parameters that differ from the estimator defaults; used by
        # the HTML renderer to mark rows as "user-set".
        self.non_default = non_default
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,616 @@
|
||||
import html
|
||||
import locale
|
||||
import re
|
||||
import types
|
||||
from contextlib import closing
|
||||
from functools import partial
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from sklearn import config_context
|
||||
from sklearn.base import BaseEstimator
|
||||
from sklearn.cluster import AgglomerativeClustering, Birch
|
||||
from sklearn.compose import ColumnTransformer, make_column_transformer
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.decomposition import PCA, TruncatedSVD
|
||||
from sklearn.ensemble import StackingClassifier, StackingRegressor, VotingClassifier
|
||||
from sklearn.feature_selection import SelectPercentile
|
||||
from sklearn.gaussian_process.kernels import ExpSineSquared
|
||||
from sklearn.impute import SimpleImputer
|
||||
from sklearn.kernel_ridge import KernelRidge
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import RandomizedSearchCV
|
||||
from sklearn.multiclass import OneVsOneClassifier
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.pipeline import FeatureUnion, Pipeline, make_pipeline
|
||||
from sklearn.preprocessing import FunctionTransformer, OneHotEncoder, StandardScaler
|
||||
from sklearn.svm import LinearSVC, LinearSVR
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.utils._repr_html.base import _HTMLDocumentationLinkMixin
|
||||
from sklearn.utils._repr_html.estimator import (
|
||||
_get_css_style,
|
||||
_get_visual_block,
|
||||
_write_label_html,
|
||||
estimator_html_repr,
|
||||
)
|
||||
from sklearn.utils.fixes import parse_version
|
||||
|
||||
|
||||
def dummy_function(x, y):
|
||||
return x + y # pragma: nocover
|
||||
|
||||
|
||||
@pytest.mark.parametrize("checked", [True, False])
|
||||
def test_write_label_html(checked):
|
||||
# Test checking logic and labeling
|
||||
name = "LogisticRegression"
|
||||
params = ""
|
||||
tool_tip = "hello-world"
|
||||
|
||||
with closing(StringIO()) as out:
|
||||
_write_label_html(out, params, name, tool_tip, checked=checked)
|
||||
html_label = out.getvalue()
|
||||
|
||||
p = (
|
||||
r'<label for="sk-estimator-id-[0-9]*"'
|
||||
r' class="sk-toggleable__label (fitted)? sk-toggleable__label-arrow">'
|
||||
r"<div><div>LogisticRegression</div></div>"
|
||||
)
|
||||
re_compiled = re.compile(p)
|
||||
assert re_compiled.search(html_label)
|
||||
assert html_label.startswith('<div class="sk-label-container">')
|
||||
assert "<pre>hello-world</pre>" in html_label
|
||||
|
||||
if checked:
|
||||
assert "checked>" in html_label
|
||||
|
||||
|
||||
@pytest.mark.parametrize("est", ["passthrough", "drop", None])
|
||||
def test_get_visual_block_single_str_none(est):
|
||||
# Test estimators that are represented by strings
|
||||
est_html_info = _get_visual_block(est)
|
||||
assert est_html_info.kind == "single"
|
||||
assert est_html_info.estimators == est
|
||||
assert est_html_info.names == str(est)
|
||||
assert est_html_info.name_details == str(est)
|
||||
|
||||
|
||||
def test_get_visual_block_single_estimator():
|
||||
est = LogisticRegression(C=10.0)
|
||||
est_html_info = _get_visual_block(est)
|
||||
assert est_html_info.kind == "single"
|
||||
assert est_html_info.estimators == est
|
||||
assert est_html_info.names == est.__class__.__name__
|
||||
assert est_html_info.name_details == str(est)
|
||||
|
||||
|
||||
def test_get_visual_block_pipeline():
|
||||
pipe = Pipeline(
|
||||
[
|
||||
("imputer", SimpleImputer()),
|
||||
("do_nothing", "passthrough"),
|
||||
("do_nothing_more", None),
|
||||
("classifier", LogisticRegression()),
|
||||
]
|
||||
)
|
||||
est_html_info = _get_visual_block(pipe)
|
||||
assert est_html_info.kind == "serial"
|
||||
assert est_html_info.estimators == tuple(step[1] for step in pipe.steps)
|
||||
assert est_html_info.names == [
|
||||
"imputer: SimpleImputer",
|
||||
"do_nothing: passthrough",
|
||||
"do_nothing_more: passthrough",
|
||||
"classifier: LogisticRegression",
|
||||
]
|
||||
assert est_html_info.name_details == [str(est) for _, est in pipe.steps]
|
||||
|
||||
|
||||
def test_get_visual_block_feature_union():
|
||||
f_union = FeatureUnion([("pca", PCA()), ("svd", TruncatedSVD())])
|
||||
est_html_info = _get_visual_block(f_union)
|
||||
assert est_html_info.kind == "parallel"
|
||||
assert est_html_info.names == ("pca", "svd")
|
||||
assert est_html_info.estimators == tuple(
|
||||
trans[1] for trans in f_union.transformer_list
|
||||
)
|
||||
assert est_html_info.name_details == (None, None)
|
||||
|
||||
|
||||
def test_get_visual_block_voting():
|
||||
clf = VotingClassifier(
|
||||
[("log_reg", LogisticRegression()), ("mlp", MLPClassifier())]
|
||||
)
|
||||
est_html_info = _get_visual_block(clf)
|
||||
assert est_html_info.kind == "parallel"
|
||||
assert est_html_info.estimators == tuple(trans[1] for trans in clf.estimators)
|
||||
assert est_html_info.names == ("log_reg", "mlp")
|
||||
assert est_html_info.name_details == (None, None)
|
||||
|
||||
|
||||
def test_get_visual_block_column_transformer():
|
||||
ct = ColumnTransformer(
|
||||
[("pca", PCA(), ["num1", "num2"]), ("svd", TruncatedSVD, [0, 3])]
|
||||
)
|
||||
est_html_info = _get_visual_block(ct)
|
||||
assert est_html_info.kind == "parallel"
|
||||
assert est_html_info.estimators == tuple(trans[1] for trans in ct.transformers)
|
||||
assert est_html_info.names == ("pca", "svd")
|
||||
assert est_html_info.name_details == (["num1", "num2"], [0, 3])
|
||||
|
||||
|
||||
def test_estimator_html_repr_an_empty_pipeline():
|
||||
"""Check that the representation of an empty Pipeline does not fail.
|
||||
|
||||
Non-regression test for:
|
||||
https://github.com/scikit-learn/scikit-learn/issues/30197
|
||||
"""
|
||||
empty_pipeline = Pipeline([])
|
||||
estimator_html_repr(empty_pipeline)
|
||||
|
||||
|
||||
def test_estimator_html_repr_pipeline():
|
||||
num_trans = Pipeline(
|
||||
steps=[("pass", "passthrough"), ("imputer", SimpleImputer(strategy="median"))]
|
||||
)
|
||||
|
||||
cat_trans = Pipeline(
|
||||
steps=[
|
||||
("imputer", SimpleImputer(strategy="constant", missing_values="empty")),
|
||||
("one-hot", OneHotEncoder(drop="first")),
|
||||
]
|
||||
)
|
||||
|
||||
preprocess = ColumnTransformer(
|
||||
[
|
||||
("num", num_trans, ["a", "b", "c", "d", "e"]),
|
||||
("cat", cat_trans, [0, 1, 2, 3]),
|
||||
]
|
||||
)
|
||||
|
||||
feat_u = FeatureUnion(
|
||||
[
|
||||
("pca", PCA(n_components=1)),
|
||||
(
|
||||
"tsvd",
|
||||
Pipeline(
|
||||
[
|
||||
("first", TruncatedSVD(n_components=3)),
|
||||
("select", SelectPercentile()),
|
||||
]
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
clf = VotingClassifier(
|
||||
[
|
||||
("lr", LogisticRegression(solver="lbfgs", random_state=1)),
|
||||
("mlp", MLPClassifier(alpha=0.001)),
|
||||
]
|
||||
)
|
||||
|
||||
pipe = Pipeline(
|
||||
[("preprocessor", preprocess), ("feat_u", feat_u), ("classifier", clf)]
|
||||
)
|
||||
html_output = estimator_html_repr(pipe)
|
||||
|
||||
# top level estimators show estimator with changes
|
||||
assert html.escape(str(pipe)) in html_output
|
||||
for _, est in pipe.steps:
|
||||
assert html.escape(str(est))[:44] in html_output
|
||||
|
||||
# low level estimators do not show changes
|
||||
with config_context(print_changed_only=True):
|
||||
assert html.escape(str(num_trans["pass"])) in html_output
|
||||
assert "<div><div>passthrough</div></div></label>" in html_output
|
||||
assert html.escape(str(num_trans["imputer"])) in html_output
|
||||
|
||||
for _, _, cols in preprocess.transformers:
|
||||
assert f"<pre>{html.escape(str(cols))}</pre>" in html_output
|
||||
|
||||
# feature union
|
||||
for name, _ in feat_u.transformer_list:
|
||||
assert f"<label>{html.escape(name)}</label>" in html_output
|
||||
|
||||
pca = feat_u.transformer_list[0][1]
|
||||
|
||||
assert html.escape(str(pca)) in html_output
|
||||
|
||||
tsvd = feat_u.transformer_list[1][1]
|
||||
first = tsvd["first"]
|
||||
select = tsvd["select"]
|
||||
assert html.escape(str(first)) in html_output
|
||||
assert html.escape(str(select)) in html_output
|
||||
|
||||
# voting classifier
|
||||
for name, est in clf.estimators:
|
||||
assert html.escape(name) in html_output
|
||||
assert html.escape(str(est)) in html_output
|
||||
|
||||
# verify that prefers-color-scheme is implemented
|
||||
assert "prefers-color-scheme" in html_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("final_estimator", [None, LinearSVC()])
|
||||
def test_stacking_classifier(final_estimator):
|
||||
estimators = [
|
||||
("mlp", MLPClassifier(alpha=0.001)),
|
||||
("tree", DecisionTreeClassifier()),
|
||||
]
|
||||
clf = StackingClassifier(estimators=estimators, final_estimator=final_estimator)
|
||||
|
||||
html_output = estimator_html_repr(clf)
|
||||
|
||||
assert html.escape(str(clf)) in html_output
|
||||
# If final_estimator's default changes from LogisticRegression
|
||||
# this should be updated
|
||||
if final_estimator is None:
|
||||
assert "LogisticRegression" in html_output
|
||||
else:
|
||||
assert final_estimator.__class__.__name__ in html_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("final_estimator", [None, LinearSVR()])
|
||||
def test_stacking_regressor(final_estimator):
|
||||
reg = StackingRegressor(
|
||||
estimators=[("svr", LinearSVR())], final_estimator=final_estimator
|
||||
)
|
||||
html_output = estimator_html_repr(reg)
|
||||
|
||||
assert html.escape(str(reg.estimators[0][0])) in html_output
|
||||
p = (
|
||||
r'<label for="sk-estimator-id-[0-9]*"'
|
||||
r' class="sk-toggleable__label (fitted)? sk-toggleable__label-arrow">'
|
||||
r"<div><div>LinearSVR</div></div>"
|
||||
)
|
||||
re_compiled = re.compile(p)
|
||||
assert re_compiled.search(html_output)
|
||||
|
||||
if final_estimator is None:
|
||||
p = (
|
||||
r'<label for="sk-estimator-id-[0-9]*"'
|
||||
r' class="sk-toggleable__label (fitted)? sk-toggleable__label-arrow">'
|
||||
r"<div><div>RidgeCV</div></div>"
|
||||
)
|
||||
re_compiled = re.compile(p)
|
||||
assert re_compiled.search(html_output)
|
||||
else:
|
||||
assert html.escape(final_estimator.__class__.__name__) in html_output
|
||||
|
||||
|
||||
def test_birch_duck_typing_meta():
|
||||
# Test duck typing meta estimators with Birch
|
||||
birch = Birch(n_clusters=AgglomerativeClustering(n_clusters=3))
|
||||
html_output = estimator_html_repr(birch)
|
||||
|
||||
# inner estimators do not show changes
|
||||
with config_context(print_changed_only=True):
|
||||
assert f"<pre>{html.escape(str(birch.n_clusters))}" in html_output
|
||||
|
||||
p = r"<div><div>AgglomerativeClustering</div></div><div>.+</div></label>"
|
||||
re_compiled = re.compile(p)
|
||||
assert re_compiled.search(html_output)
|
||||
|
||||
# outer estimator contains all changes
|
||||
assert f"<pre>{html.escape(str(birch))}" in html_output
|
||||
|
||||
|
||||
def test_ovo_classifier_duck_typing_meta():
|
||||
# Test duck typing metaestimators with OVO
|
||||
ovo = OneVsOneClassifier(LinearSVC(penalty="l1"))
|
||||
html_output = estimator_html_repr(ovo)
|
||||
|
||||
# inner estimators do not show changes
|
||||
with config_context(print_changed_only=True):
|
||||
assert f"<pre>{html.escape(str(ovo.estimator))}" in html_output
|
||||
# regex to match the start of the tag
|
||||
p = (
|
||||
r'<label for="sk-estimator-id-[0-9]*" '
|
||||
r'class="sk-toggleable__label sk-toggleable__label-arrow">'
|
||||
r"<div><div>LinearSVC</div></div>"
|
||||
)
|
||||
re_compiled = re.compile(p)
|
||||
assert re_compiled.search(html_output)
|
||||
|
||||
# outer estimator
|
||||
assert f"<pre>{html.escape(str(ovo))}" in html_output
|
||||
|
||||
|
||||
def test_duck_typing_nested_estimator():
|
||||
# Test duck typing metaestimators with random search
|
||||
kernel_ridge = KernelRidge(kernel=ExpSineSquared())
|
||||
param_distributions = {"alpha": [1, 2]}
|
||||
|
||||
kernel_ridge_tuned = RandomizedSearchCV(
|
||||
kernel_ridge,
|
||||
param_distributions=param_distributions,
|
||||
)
|
||||
html_output = estimator_html_repr(kernel_ridge_tuned)
|
||||
assert "<div><div>estimator: KernelRidge</div></div></label>" in html_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("print_changed_only", [True, False])
|
||||
def test_one_estimator_print_change_only(print_changed_only):
|
||||
pca = PCA(n_components=10)
|
||||
|
||||
with config_context(print_changed_only=print_changed_only):
|
||||
pca_repr = html.escape(str(pca))
|
||||
html_output = estimator_html_repr(pca)
|
||||
assert pca_repr in html_output
|
||||
|
||||
|
||||
def test_fallback_exists():
    """The plain-text fallback block is embedded in the HTML output."""
    estimator = PCA(n_components=10)
    output = estimator_html_repr(estimator)

    expected_prefix = (
        f'<div class="sk-text-repr-fallback"><pre>{html.escape(str(estimator))}'
    )
    assert expected_prefix in output
|
||||
|
||||
|
||||
def test_show_arrow_pipeline():
|
||||
"""Show arrow in pipeline for top level in pipeline"""
|
||||
pipe = Pipeline([("scale", StandardScaler()), ("log_Reg", LogisticRegression())])
|
||||
|
||||
html_output = estimator_html_repr(pipe)
|
||||
assert (
|
||||
'class="sk-toggleable__label sk-toggleable__label-arrow">'
|
||||
"<div><div>Pipeline</div></div>" in html_output
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_parameters_in_stacking():
|
||||
"""Invalidate stacking configuration uses default repr.
|
||||
|
||||
Non-regression test for #24009.
|
||||
"""
|
||||
stacker = StackingClassifier(estimators=[])
|
||||
|
||||
html_output = estimator_html_repr(stacker)
|
||||
assert html.escape(str(stacker)) in html_output
|
||||
|
||||
|
||||
def test_estimator_get_params_return_cls():
|
||||
"""Check HTML repr works where a value in get_params is a class."""
|
||||
|
||||
class MyEstimator:
|
||||
def get_params(self, deep=False):
|
||||
return {"inner_cls": LogisticRegression}
|
||||
|
||||
est = MyEstimator()
|
||||
assert "MyEstimator" in estimator_html_repr(est)
|
||||
|
||||
|
||||
def test_estimator_html_repr_unfitted_vs_fitted():
|
||||
"""Check that we have the information that the estimator is fitted or not in the
|
||||
HTML representation.
|
||||
"""
|
||||
|
||||
class MyEstimator(BaseEstimator):
|
||||
def fit(self, X, y):
|
||||
self.fitted_ = True
|
||||
return self
|
||||
|
||||
X, y = load_iris(return_X_y=True)
|
||||
estimator = MyEstimator()
|
||||
assert "<span>Not fitted</span>" in estimator_html_repr(estimator)
|
||||
estimator.fit(X, y)
|
||||
assert "<span>Fitted</span>" in estimator_html_repr(estimator)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"estimator",
|
||||
[
|
||||
LogisticRegression(),
|
||||
make_pipeline(StandardScaler(), LogisticRegression()),
|
||||
make_pipeline(
|
||||
make_column_transformer((StandardScaler(), slice(0, 3))),
|
||||
LogisticRegression(),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_estimator_html_repr_fitted_icon(estimator):
|
||||
"""Check that we are showing the fitted status icon only once."""
|
||||
pattern = '<span class="sk-estimator-doc-link ">i<span>Not fitted</span></span>'
|
||||
assert estimator_html_repr(estimator).count(pattern) == 1
|
||||
X, y = load_iris(return_X_y=True)
|
||||
estimator.fit(X, y)
|
||||
pattern = '<span class="sk-estimator-doc-link fitted">i<span>Fitted</span></span>'
|
||||
assert estimator_html_repr(estimator).count(pattern) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mock_version", ["1.3.0.dev0", "1.3.0"])
|
||||
def test_html_documentation_link_mixin_sklearn(mock_version):
|
||||
"""Check the behaviour of the `_HTMLDocumentationLinkMixin` class for scikit-learn
|
||||
default.
|
||||
"""
|
||||
|
||||
# mock the `__version__` where the mixin is located
|
||||
with patch("sklearn.utils._repr_html.base.__version__", mock_version):
|
||||
mixin = _HTMLDocumentationLinkMixin()
|
||||
|
||||
assert mixin._doc_link_module == "sklearn"
|
||||
sklearn_version = parse_version(mock_version)
|
||||
# we need to parse the version manually to be sure that this test is passing in
|
||||
# other branches than `main` (that is "dev").
|
||||
if sklearn_version.dev is None:
|
||||
version = f"{sklearn_version.major}.{sklearn_version.minor}"
|
||||
else:
|
||||
version = "dev"
|
||||
assert (
|
||||
mixin._doc_link_template
|
||||
== f"https://scikit-learn.org/{version}/modules/generated/"
|
||||
"{estimator_module}.{estimator_name}.html"
|
||||
)
|
||||
assert (
|
||||
mixin._get_doc_link()
|
||||
== f"https://scikit-learn.org/{version}/modules/generated/"
|
||||
"sklearn.utils._HTMLDocumentationLinkMixin.html"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"module_path,expected_module",
|
||||
[
|
||||
("prefix.mymodule", "prefix.mymodule"),
|
||||
("prefix._mymodule", "prefix"),
|
||||
("prefix.mypackage._mymodule", "prefix.mypackage"),
|
||||
("prefix.mypackage._mymodule.submodule", "prefix.mypackage"),
|
||||
("prefix.mypackage.mymodule.submodule", "prefix.mypackage.mymodule.submodule"),
|
||||
],
|
||||
)
|
||||
def test_html_documentation_link_mixin_get_doc_link_instance(
|
||||
module_path, expected_module
|
||||
):
|
||||
"""Check the behaviour of the `_get_doc_link` with various parameter."""
|
||||
|
||||
class FooBar(_HTMLDocumentationLinkMixin):
|
||||
pass
|
||||
|
||||
FooBar.__module__ = module_path
|
||||
est = FooBar()
|
||||
# if we set `_doc_link`, then we expect to infer a module and name for the estimator
|
||||
est._doc_link_module = "prefix"
|
||||
est._doc_link_template = (
|
||||
"https://website.com/{estimator_module}.{estimator_name}.html"
|
||||
)
|
||||
assert est._get_doc_link() == f"https://website.com/{expected_module}.FooBar.html"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"module_path,expected_module",
|
||||
[
|
||||
("prefix.mymodule", "prefix.mymodule"),
|
||||
("prefix._mymodule", "prefix"),
|
||||
("prefix.mypackage._mymodule", "prefix.mypackage"),
|
||||
("prefix.mypackage._mymodule.submodule", "prefix.mypackage"),
|
||||
("prefix.mypackage.mymodule.submodule", "prefix.mypackage.mymodule.submodule"),
|
||||
],
|
||||
)
|
||||
def test_html_documentation_link_mixin_get_doc_link_class(module_path, expected_module):
|
||||
"""Check the behaviour of the `_get_doc_link` when `_doc_link_module` and
|
||||
`_doc_link_template` are defined at the class level and not at the instance
|
||||
level."""
|
||||
|
||||
class FooBar(_HTMLDocumentationLinkMixin):
|
||||
_doc_link_module = "prefix"
|
||||
_doc_link_template = (
|
||||
"https://website.com/{estimator_module}.{estimator_name}.html"
|
||||
)
|
||||
|
||||
FooBar.__module__ = module_path
|
||||
est = FooBar()
|
||||
assert est._get_doc_link() == f"https://website.com/{expected_module}.FooBar.html"
|
||||
|
||||
|
||||
def test_html_documentation_link_mixin_get_doc_link_out_of_library():
|
||||
"""Check the behaviour of the `_get_doc_link` with various parameter."""
|
||||
mixin = _HTMLDocumentationLinkMixin()
|
||||
|
||||
# if the `_doc_link_module` does not refer to the root module of the estimator
|
||||
# (here the mixin), then we should return an empty string.
|
||||
mixin._doc_link_module = "xxx"
|
||||
assert mixin._get_doc_link() == ""
|
||||
|
||||
|
||||
def test_html_documentation_link_mixin_doc_link_url_param_generator_instance():
|
||||
mixin = _HTMLDocumentationLinkMixin()
|
||||
# we can bypass the generation by providing our own callable
|
||||
mixin._doc_link_template = (
|
||||
"https://website.com/{my_own_variable}.{another_variable}.html"
|
||||
)
|
||||
|
||||
def url_param_generator(estimator):
|
||||
return {
|
||||
"my_own_variable": "value_1",
|
||||
"another_variable": "value_2",
|
||||
}
|
||||
|
||||
mixin._doc_link_url_param_generator = types.MethodType(url_param_generator, mixin)
|
||||
|
||||
assert mixin._get_doc_link() == "https://website.com/value_1.value_2.html"
|
||||
|
||||
|
||||
def test_html_documentation_link_mixin_doc_link_url_param_generator_class():
|
||||
# we can bypass the generation by providing our own callable
|
||||
|
||||
def url_param_generator(estimator):
|
||||
return {
|
||||
"my_own_variable": "value_1",
|
||||
"another_variable": "value_2",
|
||||
}
|
||||
|
||||
class FooBar(_HTMLDocumentationLinkMixin):
|
||||
_doc_link_template = (
|
||||
"https://website.com/{my_own_variable}.{another_variable}.html"
|
||||
)
|
||||
_doc_link_url_param_generator = url_param_generator
|
||||
|
||||
estimator = FooBar()
|
||||
assert estimator._get_doc_link() == "https://website.com/value_1.value_2.html"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_non_utf8_locale():
|
||||
"""Pytest fixture to set non utf-8 locale during the test.
|
||||
|
||||
The locale is set to the original one after the test has run.
|
||||
"""
|
||||
try:
|
||||
locale.setlocale(locale.LC_CTYPE, "C")
|
||||
except locale.Error:
|
||||
pytest.skip("'C' locale is not available on this OS")
|
||||
|
||||
yield
|
||||
|
||||
# Resets the locale to the original one. Python calls setlocale(LC_TYPE, "")
|
||||
# at startup according to
|
||||
# https://docs.python.org/3/library/locale.html#background-details-hints-tips-and-caveats.
|
||||
# This assumes that no other locale changes have been made. For some reason,
|
||||
# on some platforms, trying to restore locale with something like
|
||||
# locale.setlocale(locale.LC_CTYPE, locale.getlocale()) raises a
|
||||
# locale.Error: unsupported locale setting
|
||||
locale.setlocale(locale.LC_CTYPE, "")
|
||||
|
||||
|
||||
def test_non_utf8_locale(set_non_utf8_locale):
|
||||
"""Checks that utf8 encoding is used when reading the CSS file.
|
||||
|
||||
Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/27725
|
||||
"""
|
||||
_get_css_style()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func, expected_name",
|
||||
[
|
||||
(lambda x: x + 1, html.escape("<lambda>")),
|
||||
(dummy_function, "dummy_function"),
|
||||
(partial(dummy_function, y=1), "dummy_function"),
|
||||
(np.vectorize(partial(dummy_function, y=1)), re.escape("vectorize(...)")),
|
||||
],
|
||||
)
|
||||
def test_function_transformer_show_caption(func, expected_name):
|
||||
# Test that function name is shown as the name and "FunctionTransformer" is shown
|
||||
# in the caption
|
||||
ft = FunctionTransformer(func)
|
||||
html_output = estimator_html_repr(ft)
|
||||
|
||||
p = (
|
||||
r'<label for="sk-estimator-id-[0-9]*" class="sk-toggleable__label fitted '
|
||||
rf'sk-toggleable__label-arrow"><div><div>{expected_name}</div>'
|
||||
r'<div class="caption">FunctionTransformer</div></div>'
|
||||
)
|
||||
re_compiled = re.compile(p)
|
||||
assert re_compiled.search(html_output)
|
||||
|
||||
|
||||
def test_estimator_html_repr_table():
    """Check that we add the table of parameters in the HTML representation."""
    estimator = LogisticRegression(C=10.0, fit_intercept=False)
    html_output = estimator_html_repr(estimator)
    assert "parameters-table" in html_output
|
||||
@@ -0,0 +1,74 @@
|
||||
import pytest
|
||||
|
||||
from sklearn import config_context
|
||||
from sklearn.utils._repr_html.params import ParamsDict, _params_html_repr, _read_params
|
||||
|
||||
|
||||
def test_params_dict_content():
    """ParamsDict stores the mapping content and tracks non-default keys."""
    # Without metadata nothing is flagged as user-set.
    plain = ParamsDict({"a": 1, "b": 2})
    assert plain["a"] == 1
    assert plain["b"] == 2
    assert plain.non_default == ()

    # Flagging "a" as non-default leaves the stored values untouched.
    flagged = ParamsDict({"a": 1, "b": 2}, non_default=("a",))
    assert flagged["a"] == 1
    assert flagged["b"] == 2
    assert flagged.non_default == ("a",)
|
||||
|
||||
|
||||
def test_params_dict_repr_html_():
    """The HTML repr renders only when the display config allows it."""
    params = ParamsDict({"a": 1, "b": 2}, non_default=("a",))
    assert "<summary>Parameters</summary>" in params._repr_html_()

    # With display="text" the HTML hook must refuse to render.
    with config_context(display="text"):
        msg = "_repr_html_ is only defined when"
        with pytest.raises(AttributeError, match=msg):
            params._repr_html_()
|
||||
|
||||
|
||||
def test_params_dict_repr_mimebundle():
    """The MIME bundle always carries text, and HTML only when enabled."""
    params = ParamsDict({"a": 1, "b": 2}, non_default=("a",))

    bundle = params._repr_mimebundle_()
    assert "text/plain" in bundle
    assert "text/html" in bundle
    assert "<summary>Parameters</summary>" in bundle["text/html"]
    assert bundle["text/plain"] == "{'a': 1, 'b': 2}"

    # Text-only display drops the HTML payload but keeps the plain one.
    with config_context(display="text"):
        bundle = params._repr_mimebundle_()
        assert "text/plain" in bundle
        assert "text/html" not in bundle
|
||||
|
||||
|
||||
def test_read_params():
    """Check the behavior of the `_read_params` function."""
    out = _read_params("a", 1, tuple())
    assert out["param_type"] == "default"
    assert out["param_name"] == "a"
    assert out["param_value"] == "1"

    # check non-default parameters
    out = _read_params("a", 1, ("a",))
    assert out["param_type"] == "user-set"
    assert out["param_name"] == "a"
    assert out["param_value"] == "1"

    # check that we escape html tags
    tag_injection = "<script>alert('xss')</script>"
    out = _read_params("a", tag_injection, tuple())
    # repr() wraps the payload in double quotes (the string contains single
    # quotes) and html.escape then entity-encodes ", <, > and ' — so the
    # expected value is the fully entity-escaped repr, never raw markup.
    assert (
        out["param_value"]
        == "&quot;&lt;script&gt;alert(&#x27;xss&#x27;)&lt;/script&gt;&quot;"
    )
    assert out["param_name"] == "a"
    assert out["param_type"] == "default"
|
||||
|
||||
|
||||
def test_params_html_repr():
    """Check returned HTML template"""
    rendered = _params_html_repr(ParamsDict({"a": 1, "b": 2}))
    assert "parameters-table" in rendered
    assert "estimator-table" in rendered
|
||||
317
venv/lib/python3.12/site-packages/sklearn/utils/_response.py
Normal file
317
venv/lib/python3.12/site-packages/sklearn/utils/_response.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""Utilities to get the response values of a classifier or a regressor.
|
||||
|
||||
It allows to make uniform checks and validation.
|
||||
"""
|
||||
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..base import is_classifier
|
||||
from .multiclass import type_of_target
|
||||
from .validation import _check_response_method, check_is_fitted
|
||||
|
||||
|
||||
def _process_predict_proba(*, y_pred, target_type, classes, pos_label):
|
||||
"""Get the response values when the response method is `predict_proba`.
|
||||
|
||||
This function process the `y_pred` array in the binary and multi-label cases.
|
||||
In the binary case, it selects the column corresponding to the positive
|
||||
class. In the multi-label case, it stacks the predictions if they are not
|
||||
in the "compressed" format `(n_samples, n_outputs)`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_pred : ndarray
|
||||
Output of `estimator.predict_proba`. The shape depends on the target type:
|
||||
|
||||
- for binary classification, it is a 2d array of shape `(n_samples, 2)`;
|
||||
- for multiclass classification, it is a 2d array of shape
|
||||
`(n_samples, n_classes)`;
|
||||
- for multilabel classification, it is either a list of 2d arrays of shape
|
||||
`(n_samples, 2)` (e.g. `RandomForestClassifier` or `KNeighborsClassifier`) or
|
||||
an array of shape `(n_samples, n_outputs)` (e.g. `MLPClassifier` or
|
||||
`RidgeClassifier`).
|
||||
|
||||
target_type : {"binary", "multiclass", "multilabel-indicator"}
|
||||
Type of the target.
|
||||
|
||||
classes : ndarray of shape (n_classes,) or list of such arrays
|
||||
Class labels as reported by `estimator.classes_`.
|
||||
|
||||
pos_label : int, float, bool or str
|
||||
Only used with binary and multiclass targets.
|
||||
|
||||
Returns
|
||||
-------
|
||||
y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
|
||||
(n_samples, n_output)
|
||||
Compressed predictions format as requested by the metrics.
|
||||
"""
|
||||
if target_type == "binary" and y_pred.shape[1] < 2:
|
||||
# We don't handle classifiers trained on a single class.
|
||||
raise ValueError(
|
||||
f"Got predict_proba of shape {y_pred.shape}, but need "
|
||||
"classifier with two classes."
|
||||
)
|
||||
|
||||
if target_type == "binary":
|
||||
col_idx = np.flatnonzero(classes == pos_label)[0]
|
||||
return y_pred[:, col_idx]
|
||||
elif target_type == "multilabel-indicator":
|
||||
# Use a compress format of shape `(n_samples, n_output)`.
|
||||
# Only `MLPClassifier` and `RidgeClassifier` return an array of shape
|
||||
# `(n_samples, n_outputs)`.
|
||||
if isinstance(y_pred, list):
|
||||
# list of arrays of shape `(n_samples, 2)`
|
||||
return np.vstack([p[:, -1] for p in y_pred]).T
|
||||
else:
|
||||
# array of shape `(n_samples, n_outputs)`
|
||||
return y_pred
|
||||
|
||||
return y_pred
|
||||
|
||||
|
||||
def _process_decision_function(*, y_pred, target_type, classes, pos_label):
|
||||
"""Get the response values when the response method is `decision_function`.
|
||||
|
||||
This function process the `y_pred` array in the binary and multi-label cases.
|
||||
In the binary case, it inverts the sign of the score if the positive label
|
||||
is not `classes[1]`. In the multi-label case, it stacks the predictions if
|
||||
they are not in the "compressed" format `(n_samples, n_outputs)`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_pred : ndarray
|
||||
Output of `estimator.decision_function`. The shape depends on the target type:
|
||||
|
||||
- for binary classification, it is a 1d array of shape `(n_samples,)` where the
|
||||
sign is assuming that `classes[1]` is the positive class;
|
||||
- for multiclass classification, it is a 2d array of shape
|
||||
`(n_samples, n_classes)`;
|
||||
- for multilabel classification, it is a 2d array of shape `(n_samples,
|
||||
n_outputs)`.
|
||||
|
||||
target_type : {"binary", "multiclass", "multilabel-indicator"}
|
||||
Type of the target.
|
||||
|
||||
classes : ndarray of shape (n_classes,) or list of such arrays
|
||||
Class labels as reported by `estimator.classes_`.
|
||||
|
||||
pos_label : int, float, bool or str
|
||||
Only used with binary and multiclass targets.
|
||||
|
||||
Returns
|
||||
-------
|
||||
y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
|
||||
(n_samples, n_output)
|
||||
Compressed predictions format as requested by the metrics.
|
||||
"""
|
||||
if target_type == "binary" and pos_label == classes[0]:
|
||||
return -1 * y_pred
|
||||
return y_pred
|
||||
|
||||
|
||||
def _get_response_values(
    estimator,
    X,
    response_method,
    pos_label=None,
    return_response_method_used=False,
):
    """Compute the response values of a classifier, an outlier detector, or a regressor.

    The response values are predictions such that it follows the following shape:

    - for binary classification, it is a 1d array of shape `(n_samples,)`;
    - for multiclass classification, it is a 2d array of shape `(n_samples, n_classes)`;
    - for multilabel classification, it is a 2d array of shape `(n_samples, n_outputs)`;
    - for outlier detection, it is a 1d array of shape `(n_samples,)`;
    - for regression, it is a 1d array of shape `(n_samples,)`.

    If `estimator` is a binary classifier, also return the label for the
    effective positive class.

    This utility is used primarily in the displays and the scikit-learn scorers.

    .. versionadded:: 1.3

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier, outlier detector, or regressor or a
        fitted :class:`~sklearn.pipeline.Pipeline` in which the last estimator is a
        classifier, an outlier detector, or a regressor.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    response_method : {"predict_proba", "predict_log_proba", "decision_function", \
            "predict"} or list of such str
        Specifies the response method to use get prediction from an estimator
        (i.e. :term:`predict_proba`, :term:`predict_log_proba`,
        :term:`decision_function` or :term:`predict`). Possible choices are:

        - if `str`, it corresponds to the name to the method to return;
        - if a list of `str`, it provides the method names in order of
          preference. The method returned corresponds to the first method in
          the list and which is implemented by `estimator`.

    pos_label : int, float, bool or str, default=None
        The class considered as the positive class when computing
        the metrics. If `None` and target is 'binary', `estimators.classes_[1]` is
        considered as the positive class.

    return_response_method_used : bool, default=False
        Whether to return the response method used to compute the response
        values.

        .. versionadded:: 1.4

    Returns
    -------
    y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
            (n_samples, n_outputs)
        Target scores calculated from the provided `response_method`
        and `pos_label`.

    pos_label : int, float, bool, str or None
        The class considered as the positive class when computing
        the metrics. Returns `None` if `estimator` is a regressor or an outlier
        detector.

    response_method_used : str
        The response method used to compute the response values. Only returned
        if `return_response_method_used` is `True`.

        .. versionadded:: 1.4

    Raises
    ------
    ValueError
        If `pos_label` is not a valid label.
        If the shape of `y_pred` is not consistent for binary classifier.
        If the response method can be applied to a classifier only and
        `estimator` is a regressor.
    """
    # Imported locally to avoid a circular import at module load time.
    from sklearn.base import is_classifier, is_outlier_detector

    if is_classifier(estimator):
        response_fn = _check_response_method(estimator, response_method)
        classes = estimator.classes_
        target_type = type_of_target(classes)

        if target_type in ("binary", "multiclass"):
            if pos_label is not None and pos_label not in classes.tolist():
                raise ValueError(
                    f"pos_label={pos_label} is not a valid label: It should be "
                    f"one of {classes}"
                )
            elif pos_label is None and target_type == "binary":
                # By convention the last class (== classes_[1] for a binary
                # problem) is treated as the positive class.
                pos_label = classes[-1]

        y_pred = response_fn(X)

        method_name = response_fn.__name__
        if method_name in ("predict_proba", "predict_log_proba"):
            y_pred = _process_predict_proba(
                y_pred=y_pred,
                target_type=target_type,
                classes=classes,
                pos_label=pos_label,
            )
        elif method_name == "decision_function":
            y_pred = _process_decision_function(
                y_pred=y_pred,
                target_type=target_type,
                classes=classes,
                pos_label=pos_label,
            )
    elif is_outlier_detector(estimator):
        response_fn = _check_response_method(estimator, response_method)
        y_pred = response_fn(X)
        pos_label = None
    else:
        # `estimator` is a regressor: only `predict` is meaningful.
        if response_method != "predict":
            raise ValueError(
                f"{estimator.__class__.__name__} should either be a classifier to be "
                f"used with response_method={response_method} or the response_method "
                "should be 'predict'. Got a regressor with response_method="
                f"{response_method} instead."
            )
        response_fn = estimator.predict
        y_pred = response_fn(X)
        pos_label = None

    if return_response_method_used:
        return y_pred, pos_label, response_fn.__name__
    return y_pred, pos_label
|
||||
|
||||
|
||||
def _get_response_values_binary(
    estimator, X, response_method, pos_label=None, return_response_method_used=False
):
    """Compute the response values of a binary classifier.

    Thin wrapper around `_get_response_values` that first validates that
    `estimator` is a fitted classifier with exactly two classes, and resolves
    the `'auto'` response method.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a binary classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    response_method : {'auto', 'predict_proba', 'decision_function'}
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    pos_label : int, float, bool or str, default=None
        The class considered as the positive class when computing
        the metrics. By default, `estimators.classes_[1]` is
        considered as the positive class.

    return_response_method_used : bool, default=False
        Whether to return the response method used to compute the response
        values.

        .. versionadded:: 1.5

    Returns
    -------
    y_pred : ndarray of shape (n_samples,)
        Target scores calculated from the provided response_method
        and pos_label.

    pos_label : int, float, bool or str
        The class considered as the positive class when computing
        the metrics.

    response_method_used : str
        The response method used to compute the response values. Only returned
        if `return_response_method_used` is `True`.

        .. versionadded:: 1.5
    """
    err_prefix = "Expected 'estimator' to be a binary classifier."

    check_is_fitted(estimator)
    if not is_classifier(estimator):
        raise ValueError(
            err_prefix + f" Got {estimator.__class__.__name__} instead."
        )
    if len(estimator.classes_) != 2:
        raise ValueError(
            err_prefix + f" Got {len(estimator.classes_)} classes instead."
        )

    # 'auto' means: prefer predict_proba, fall back to decision_function.
    resolved_response_method = (
        ["predict_proba", "decision_function"]
        if response_method == "auto"
        else response_method
    )

    return _get_response_values(
        estimator,
        X,
        resolved_response_method,
        pos_label=pos_label,
        return_response_method_used=return_response_method_used,
    )
|
||||
Binary file not shown.
@@ -0,0 +1,76 @@
|
||||
{{py:

"""
Dataset abstractions for sequential data access.

Template file for easily generate fused types consistent code using Tempita
(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py).

Generated file: _seq_dataset.pxd

Each class is duplicated for all dtypes (float and double). The keywords
between double braces are substituted during the build.
"""

# name_suffix, c_type
dtypes = [('64', 'float64_t'),
          ('32', 'float32_t')]

}}
"""Dataset abstractions for sequential data access."""

from ._typedefs cimport float32_t, float64_t, intp_t, uint32_t

# SequentialDataset and its two concrete subclasses are (optionally randomized)
# iterators over the rows of a matrix X and corresponding target values y.

{{for name_suffix, c_type in dtypes}}

#------------------------------------------------------------------------------

# Abstract base: owns the (shufflable) index array and the iteration cursor;
# subclasses implement `_sample` for their storage layout.
cdef class SequentialDataset{{name_suffix}}:
    cdef int current_index
    cdef int[::1] index
    cdef int *index_data_ptr
    cdef Py_ssize_t n_samples
    cdef uint32_t seed

    cdef void shuffle(self, uint32_t seed) noexcept nogil
    cdef int _get_next_index(self) noexcept nogil
    cdef int _get_random_index(self) noexcept nogil

    cdef void _sample(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                      int *nnz, {{c_type}} *y, {{c_type}} *sample_weight,
                      int current_index) noexcept nogil
    cdef void next(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                   int *nnz, {{c_type}} *y, {{c_type}} *sample_weight) noexcept nogil
    cdef int random(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                    int *nnz, {{c_type}} *y, {{c_type}} *sample_weight) noexcept nogil


# Dense implementation over a C-contiguous 2d array.
cdef class ArrayDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
    cdef const {{c_type}}[:, ::1] X
    cdef const {{c_type}}[::1] Y
    cdef const {{c_type}}[::1] sample_weights
    cdef Py_ssize_t n_features
    cdef intp_t X_stride
    cdef {{c_type}} *X_data_ptr
    cdef {{c_type}} *Y_data_ptr
    cdef const int[::1] feature_indices
    cdef int *feature_indices_ptr
    cdef {{c_type}} *sample_weight_data


# Sparse implementation over the three CSR buffers (data/indptr/indices).
cdef class CSRDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
    cdef const {{c_type}}[::1] X_data
    cdef const int[::1] X_indptr
    cdef const int[::1] X_indices
    cdef const {{c_type}}[::1] Y
    cdef const {{c_type}}[::1] sample_weights
    cdef {{c_type}} *X_data_ptr
    cdef int *X_indptr_ptr
    cdef int *X_indices_ptr
    cdef {{c_type}} *Y_data_ptr
    cdef {{c_type}} *sample_weight_data

{{endfor}}
|
||||
@@ -0,0 +1,348 @@
|
||||
{{py:
|
||||
|
||||
"""
|
||||
Dataset abstractions for sequential data access.
|
||||
Template file for easily generate fused types consistent code using Tempita
|
||||
(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py).
|
||||
|
||||
Generated file: _seq_dataset.pyx
|
||||
|
||||
Each class is duplicated for all dtypes (float and double). The keywords
|
||||
between double braces are substituted during the build.
|
||||
"""
|
||||
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# name_suffix, c_type, np_type
|
||||
dtypes = [('64', 'float64_t', 'np.float64'),
|
||||
('32', 'float32_t', 'np.float32')]
|
||||
|
||||
}}
|
||||
"""Dataset abstractions for sequential data access."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
cimport cython
|
||||
from libc.limits cimport INT_MAX
|
||||
|
||||
from ._random cimport our_rand_r
|
||||
from ._typedefs cimport float32_t, float64_t, uint32_t
|
||||
|
||||
{{for name_suffix, c_type, np_type in dtypes}}
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
cdef class SequentialDataset{{name_suffix}}:
    """Base class for datasets with sequential data access.

    SequentialDataset is used to iterate over the rows of a matrix X and
    corresponding target values y, i.e. to iterate over samples.
    There are two methods to get the next sample:
        - next : Iterate sequentially (optionally randomized)
        - random : Iterate randomly (with replacement)

    Attributes
    ----------
    index : np.ndarray
        Index array for fast shuffling.

    index_data_ptr : int
        Pointer to the index array.

    current_index : int
        Index of current sample in ``index``.
        The index of current sample in the data is given by
        index_data_ptr[current_index].

    n_samples : Py_ssize_t
        Number of samples in the dataset.

    seed : uint32_t
        Seed used for random sampling. This attribute is modified at each call to the
        `random` method.
    """

    cdef void next(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                   int *nnz, {{c_type}} *y, {{c_type}} *sample_weight) noexcept nogil:
        """Get the next example ``x`` from the dataset.

        This method gets the next sample looping sequentially over all samples.
        The order can be shuffled with the method ``shuffle``.
        Shuffling once before iterating over all samples corresponds to a
        random draw without replacement. It is used for instance in SGD solver.

        Parameters
        ----------
        x_data_ptr : {{c_type}}**
            A pointer to the {{c_type}} array which holds the feature
            values of the next example.

        x_ind_ptr : np.intc**
            A pointer to the int array which holds the feature
            indices of the next example.

        nnz : int*
            A pointer to an int holding the number of non-zero
            values of the next example.

        y : {{c_type}}*
            The target value of the next example.

        sample_weight : {{c_type}}*
            The weight of the next example.
        """
        cdef int current_index = self._get_next_index()
        self._sample(x_data_ptr, x_ind_ptr, nnz, y, sample_weight,
                     current_index)

    cdef int random(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                    int *nnz, {{c_type}} *y, {{c_type}} *sample_weight) noexcept nogil:
        """Get a random example ``x`` from the dataset.

        This method gets next sample chosen randomly over a uniform
        distribution. It corresponds to a random draw with replacement.
        It is used for instance in SAG solver.

        Parameters
        ----------
        x_data_ptr : {{c_type}}**
            A pointer to the {{c_type}} array which holds the feature
            values of the next example.

        x_ind_ptr : np.intc**
            A pointer to the int array which holds the feature
            indices of the next example.

        nnz : int*
            A pointer to an int holding the number of non-zero
            values of the next example.

        y : {{c_type}}*
            The target value of the next example.

        sample_weight : {{c_type}}*
            The weight of the next example.

        Returns
        -------
        current_index : int
            Index of current sample.
        """
        cdef int current_index = self._get_random_index()
        self._sample(x_data_ptr, x_ind_ptr, nnz, y, sample_weight,
                     current_index)
        return current_index

    cdef void shuffle(self, uint32_t seed) noexcept nogil:
        """Permutes the ordering of examples."""
        # Fisher-Yates shuffle
        cdef int *ind = self.index_data_ptr
        cdef int n = self.n_samples
        cdef unsigned i, j
        for i in range(n - 1):
            j = i + our_rand_r(&seed) % (n - i)
            ind[i], ind[j] = ind[j], ind[i]

    cdef int _get_next_index(self) noexcept nogil:
        # Advance the cursor by one position, wrapping back to 0 after the
        # last sample has been served.
        cdef int current_index = self.current_index
        if current_index >= (self.n_samples - 1):
            current_index = -1

        current_index += 1
        self.current_index = current_index
        return self.current_index

    cdef int _get_random_index(self) noexcept nogil:
        # Uniform draw with replacement; advances `self.seed` (PRNG state).
        cdef int n = self.n_samples
        cdef int current_index = our_rand_r(&self.seed) % n
        self.current_index = current_index
        return current_index

    cdef void _sample(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                      int *nnz, {{c_type}} *y, {{c_type}} *sample_weight,
                      int current_index) noexcept nogil:
        # No-op in the base class; concrete subclasses fill the output
        # pointers for their storage layout.
        pass

    def _shuffle_py(self, uint32_t seed):
        """python function used for easy testing"""
        self.shuffle(seed)

    def _next_py(self):
        """python function used for easy testing"""
        cdef int current_index = self._get_next_index()
        return self._sample_py(current_index)

    def _random_py(self):
        """python function used for easy testing"""
        cdef int current_index = self._get_random_index()
        return self._sample_py(current_index)

    def _sample_py(self, int current_index):
        """python function used for easy testing"""
        cdef {{c_type}}* x_data_ptr
        cdef int* x_indices_ptr
        cdef int nnz, j
        cdef {{c_type}} y, sample_weight

        # call _sample in cython
        self._sample(&x_data_ptr, &x_indices_ptr, &nnz, &y, &sample_weight,
                     current_index)

        # transform the pointed data in numpy CSR array
        cdef {{c_type}}[:] x_data = np.empty(nnz, dtype={{np_type}})
        cdef int[:] x_indices = np.empty(nnz, dtype=np.int32)
        cdef int[:] x_indptr = np.asarray([0, nnz], dtype=np.int32)

        for j in range(nnz):
            x_data[j] = x_data_ptr[j]
            x_indices[j] = x_indices_ptr[j]

        cdef int sample_idx = self.index_data_ptr[current_index]

        return (
            (np.asarray(x_data), np.asarray(x_indices), np.asarray(x_indptr)),
            y,
            sample_weight,
            sample_idx,
        )
|
||||
|
||||
|
||||
cdef class ArrayDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
    """Dataset backed by a two-dimensional numpy array.

    The dtype of the numpy array is expected to be ``{{np_type}}`` ({{c_type}})
    and C-style memory layout.
    """

    def __cinit__(
        self,
        const {{c_type}}[:, ::1] X,
        const {{c_type}}[::1] Y,
        const {{c_type}}[::1] sample_weights,
        uint32_t seed=1,
    ):
        """A ``SequentialDataset`` backed by a two-dimensional numpy array.

        Parameters
        ----------
        X : ndarray, dtype={{c_type}}, ndim=2, mode='c'
            The sample array, of shape(n_samples, n_features)

        Y : ndarray, dtype={{c_type}}, ndim=1, mode='c'
            The target array, of shape(n_samples, )

        sample_weights : ndarray, dtype={{c_type}}, ndim=1, mode='c'
            The weight of each sample, of shape(n_samples,)
        """
        # Sample/feature indices are stored as C ints, so both dimensions
        # must fit in `int`.
        if X.shape[0] > INT_MAX or X.shape[1] > INT_MAX:
            raise ValueError("More than %d samples or features not supported;"
                             " got (%d, %d)."
                             % (INT_MAX, X.shape[0], X.shape[1]))

        # keep a reference to the data to prevent garbage collection
        self.X = X
        self.Y = Y
        self.sample_weights = sample_weights

        self.n_samples = X.shape[0]
        self.n_features = X.shape[1]

        # A dense row reports every feature index as "non-zero".
        self.feature_indices = np.arange(0, self.n_features, dtype=np.intc)
        self.feature_indices_ptr = <int *> &self.feature_indices[0]

        self.current_index = -1
        self.X_stride = X.strides[0] // X.itemsize
        self.X_data_ptr = <{{c_type}} *> &X[0, 0]
        self.Y_data_ptr = <{{c_type}} *> &Y[0]
        self.sample_weight_data = <{{c_type}} *> &sample_weights[0]

        # Use index array for fast shuffling
        self.index = np.arange(0, self.n_samples, dtype=np.intc)
        self.index_data_ptr = <int *> &self.index[0]
        # seed should not be 0 for our_rand_r
        self.seed = max(seed, 1)

    cdef void _sample(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                      int *nnz, {{c_type}} *y, {{c_type}} *sample_weight,
                      int current_index) noexcept nogil:
        # Map the (possibly shuffled) cursor position to the actual row, then
        # point the outputs directly into the backing array (no copy).
        cdef long long sample_idx = self.index_data_ptr[current_index]
        cdef long long offset = sample_idx * self.X_stride

        y[0] = self.Y_data_ptr[sample_idx]
        x_data_ptr[0] = self.X_data_ptr + offset
        x_ind_ptr[0] = self.feature_indices_ptr
        nnz[0] = self.n_features
        sample_weight[0] = self.sample_weight_data[sample_idx]
|
||||
|
||||
|
||||
cdef class CSRDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
    """A ``SequentialDataset`` backed by a scipy sparse CSR matrix. """

    def __cinit__(
        self,
        const {{c_type}}[::1] X_data,
        const int[::1] X_indptr,
        const int[::1] X_indices,
        const {{c_type}}[::1] Y,
        const {{c_type}}[::1] sample_weights,
        uint32_t seed=1,
    ):
        """Dataset backed by a scipy sparse CSR matrix.

        The feature indices of ``x`` are given by x_ind_ptr[0:nnz].
        The corresponding feature values are given by
        x_data_ptr[0:nnz].

        Parameters
        ----------
        X_data : ndarray, dtype={{c_type}}, ndim=1, mode='c'
            The data array of the CSR features matrix.

        X_indptr : ndarray, dtype=np.intc, ndim=1, mode='c'
            The index pointer array of the CSR features matrix.

        X_indices : ndarray, dtype=np.intc, ndim=1, mode='c'
            The column indices array of the CSR features matrix.

        Y : ndarray, dtype={{c_type}}, ndim=1, mode='c'
            The target values.

        sample_weights : ndarray, dtype={{c_type}}, ndim=1, mode='c'
            The weight of each sample.
        """
        # keep a reference to the data to prevent garbage collection
        self.X_data = X_data
        self.X_indptr = X_indptr
        self.X_indices = X_indices
        self.Y = Y
        self.sample_weights = sample_weights

        self.n_samples = Y.shape[0]
        self.current_index = -1
        self.X_data_ptr = <{{c_type}} *> &X_data[0]
        self.X_indptr_ptr = <int *> &X_indptr[0]
        self.X_indices_ptr = <int *> &X_indices[0]

        self.Y_data_ptr = <{{c_type}} *> &Y[0]
        self.sample_weight_data = <{{c_type}} *> &sample_weights[0]

        # Use index array for fast shuffling
        self.index = np.arange(self.n_samples, dtype=np.intc)
        self.index_data_ptr = <int *> &self.index[0]
        # seed should not be 0 for our_rand_r
        self.seed = max(seed, 1)

    cdef void _sample(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                      int *nnz, {{c_type}} *y, {{c_type}} *sample_weight,
                      int current_index) noexcept nogil:
        # Map the (possibly shuffled) cursor position to the actual row;
        # the row's non-zeros live in X_data/X_indices between
        # X_indptr[row] and X_indptr[row + 1] (no copy is made).
        cdef long long sample_idx = self.index_data_ptr[current_index]
        cdef long long offset = self.X_indptr_ptr[sample_idx]
        y[0] = self.Y_data_ptr[sample_idx]
        x_data_ptr[0] = self.X_data_ptr + offset
        x_ind_ptr[0] = self.X_indices_ptr + offset
        nnz[0] = self.X_indptr_ptr[sample_idx + 1] - offset
        sample_weight[0] = self.sample_weight_data[sample_idx]
|
||||
|
||||
|
||||
{{endfor}}
|
||||
460
venv/lib/python3.12/site-packages/sklearn/utils/_set_output.py
Normal file
460
venv/lib/python3.12/site-packages/sklearn/utils/_set_output.py
Normal file
@@ -0,0 +1,460 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import importlib
|
||||
from functools import wraps
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
from scipy.sparse import issparse
|
||||
|
||||
from .._config import get_config
|
||||
from ._available_if import available_if
|
||||
|
||||
|
||||
def check_library_installed(library):
    """Return the imported module for ``library``.

    Raises an informative ``ImportError`` (chained to the original one) when
    the requested dataframe library is not installed.
    """
    try:
        module = importlib.import_module(library)
    except ImportError as exc:
        raise ImportError(
            f"Setting output container to '{library}' requires {library} to be"
            " installed"
        ) from exc
    return module
|
||||
|
||||
|
||||
def get_columns(columns):
    """Resolve ``columns``: call it when it is a callable, else return as-is.

    A callable that raises is deliberately mapped to ``None`` (best-effort:
    "no column names available").
    """
    if not callable(columns):
        return columns
    try:
        return columns()
    except Exception:
        return None
|
||||
|
||||
|
||||
@runtime_checkable
class ContainerAdapterProtocol(Protocol):
    # Name of the dataframe library this adapter wraps (e.g. "pandas").
    container_lib: str

    def create_container(self, X_output, X_original, columns, inplace=False):
        """Create container from `X_output` with additional metadata.

        Parameters
        ----------
        X_output : {ndarray, dataframe}
            Data to wrap.

        X_original : {ndarray, dataframe}
            Original input dataframe. This is used to extract the metadata that should
            be passed to `X_output`, e.g. pandas row index.

        columns : callable, ndarray, or None
            The column names or a callable that returns the column names. The
            callable is useful if the column names require some computation. If `None`,
            then no columns are passed to the container's constructor.

        inplace : bool, default=False
            Whether or not we intend to modify `X_output` in-place. However, it does
            not guarantee that we return the same object if the in-place operation
            is not possible.

        Returns
        -------
        wrapped_output : container_type
            `X_output` wrapped into the container type.
        """

    def is_supported_container(self, X):
        """Return True if X is a supported container.

        Parameters
        ----------
        X : container
            Container to be checked.

        Returns
        -------
        is_supported_container : bool
            True if X is a supported container.
        """

    def rename_columns(self, X, columns):
        """Rename columns in `X`.

        Parameters
        ----------
        X : container
            Container which columns is updated.

        columns : ndarray of str
            Columns to update the `X`'s columns with.

        Returns
        -------
        updated_container : container
            Container with new names.
        """

    def hstack(self, Xs):
        """Stack containers horizontally (column-wise).

        Parameters
        ----------
        Xs : list of containers
            List of containers to stack.

        Returns
        -------
        stacked_Xs : container
            Stacked containers.
        """
|
||||
|
||||
|
||||
class PandasAdapter:
    """Adapter implementing `ContainerAdapterProtocol` for pandas DataFrames."""

    container_lib = "pandas"

    def create_container(self, X_output, X_original, columns, inplace=True):
        pandas = check_library_installed("pandas")
        resolved_columns = get_columns(columns)

        needs_new_frame = not (inplace and isinstance(X_output, pandas.DataFrame))
        if needs_new_frame:
            # Pick the row index to carry over. We cannot rely on
            # `getattr(container, "index")` because `list` also exposes an
            # `index` attribute (the method).
            if isinstance(X_output, pandas.DataFrame):
                carried_index = X_output.index
            elif isinstance(X_original, (pandas.DataFrame, pandas.Series)):
                carried_index = X_original.index
            else:
                carried_index = None

            # Columns are applied afterwards via `rename_columns`: passing
            # them to the constructor would mean column *selection*, not
            # renaming.
            X_output = pandas.DataFrame(X_output, index=carried_index, copy=not inplace)

        if resolved_columns is None:
            return X_output
        return self.rename_columns(X_output, resolved_columns)

    def is_supported_container(self, X):
        pandas = check_library_installed("pandas")
        return isinstance(X, pandas.DataFrame)

    def rename_columns(self, X, columns):
        # `DataFrame.rename` takes a mapping, which cannot express the
        # potentially duplicated column names we may have at this stage, so
        # assign the columns attribute directly.
        X.columns = columns
        return X

    def hstack(self, Xs):
        pandas = check_library_installed("pandas")
        return pandas.concat(Xs, axis=1)
|
||||
|
||||
|
||||
class PolarsAdapter:
    """Adapter creating, checking and combining polars DataFrames."""

    container_lib = "polars"

    def create_container(self, X_output, X_original, columns, inplace=True):
        """Wrap `X_output` into a polars DataFrame with the given `columns`.

        When `inplace=True` and `X_output` already is a polars DataFrame, it
        is modified in place; otherwise a new DataFrame is created.
        """
        pl = check_library_installed("polars")
        column_names = get_columns(columns)
        # polars expects a plain list of names, not an ndarray.
        if isinstance(column_names, np.ndarray):
            column_names = column_names.tolist()

        if inplace and isinstance(X_output, pl.DataFrame):
            if column_names is None:
                return X_output
            return self.rename_columns(X_output, column_names)

        # In all other cases, we need to create a new DataFrame.
        return pl.DataFrame(X_output, schema=column_names, orient="row")

    def is_supported_container(self, X):
        """Return True if `X` is a polars DataFrame."""
        pl = check_library_installed("polars")
        return isinstance(X, pl.DataFrame)

    def rename_columns(self, X, columns):
        """Overwrite the column labels of `X` with `columns`."""
        # We cannot use `rename` since it takes a dictionary and at this stage
        # we have potentially duplicate column names in `X`.
        X.columns = columns
        return X

    def hstack(self, Xs):
        """Concatenate the DataFrames in `Xs` column-wise."""
        pl = check_library_installed("polars")
        return pl.concat(Xs, how="horizontal")
|
||||
|
||||
|
||||
class ContainerAdaptersManager:
    """Registry mapping container library names to their adapters."""

    def __init__(self):
        # Keys are library names (e.g. "pandas"), values adapter instances.
        self.adapters = {}

    @property
    def supported_outputs(self):
        """Names accepted by `set_output`: "default" plus registered libraries."""
        outputs = set(self.adapters)
        outputs.add("default")
        return outputs

    def register(self, adapter):
        """Register `adapter` under its `container_lib` name."""
        self.adapters[adapter.container_lib] = adapter
|
||||
|
||||
|
||||
# Process-wide registry of dataframe adapters. Pandas and polars are the two
# container libraries currently registered for `set_output`.
ADAPTERS_MANAGER = ContainerAdaptersManager()
ADAPTERS_MANAGER.register(PandasAdapter())
ADAPTERS_MANAGER.register(PolarsAdapter())
|
||||
|
||||
|
||||
def _get_adapter_from_container(container):
    """Get the adapter that knows how to handle such container.

    See :class:`sklearn.utils._set_output.ContainerAdapterProtocol` for more
    details.
    """
    # The root of the module path identifies the dataframe library, e.g.
    # "pandas.core.frame" -> "pandas".
    library_name = type(container).__module__.split(".")[0]
    try:
        return ADAPTERS_MANAGER.adapters[library_name]
    except KeyError as exc:
        registered = list(ADAPTERS_MANAGER.adapters.keys())
        raise ValueError(
            "The container does not have a registered adapter in scikit-learn. "
            f"Available adapters are: {registered} while the container "
            f"provided is: {container!r}."
        ) from exc
|
||||
|
||||
|
||||
def _get_container_adapter(method, estimator=None):
    """Get container adapter.

    Returns the adapter configured for `method` on `estimator` (or by the
    global config), or `None` when the configured output is "default".
    """
    dense_config = _get_output_config(method, estimator)["dense"]
    # "default" has no adapter registered, so `get` yields None for it.
    return ADAPTERS_MANAGER.adapters.get(dense_config)
|
||||
|
||||
|
||||
def _get_output_config(method, estimator=None):
    """Get output config based on estimator and global configuration.

    Parameters
    ----------
    method : {"transform"}
        Estimator's method for which the output container is looked up.

    estimator : estimator instance or None
        Estimator to get the output configuration from. If `None`, check global
        configuration is used.

    Returns
    -------
    config : dict
        Dictionary with keys:

        - "dense": specifies the dense container for `method`. This can be
          `"default"` or `"pandas"`.
    """
    # A per-estimator setting (made by `set_output`) takes precedence over the
    # global configuration.
    estimator_config = getattr(estimator, "_sklearn_output_config", {})
    try:
        dense_config = estimator_config[method]
    except KeyError:
        dense_config = get_config()[f"{method}_output"]

    supported_outputs = ADAPTERS_MANAGER.supported_outputs
    if dense_config not in supported_outputs:
        raise ValueError(
            f"output config must be in {sorted(supported_outputs)}, got {dense_config}"
        )

    return {"dense": dense_config}
|
||||
|
||||
|
||||
def _wrap_data_with_container(method, data_to_wrap, original_input, estimator):
    """Wrap output with container based on an estimator's or global config.

    Parameters
    ----------
    method : {"transform"}
        Estimator's method to get container output for.

    data_to_wrap : {ndarray, dataframe}
        Data to wrap with container.

    original_input : {ndarray, dataframe}
        Original input of function.

    estimator : estimator instance
        Estimator with to get the output configuration from.

    Returns
    -------
    output : {ndarray, dataframe}
        If the output config is "default" or the estimator is not configured
        for wrapping return `data_to_wrap` unchanged.
        If the output config is "pandas", return `data_to_wrap` as a pandas
        DataFrame.
    """
    dense_config = _get_output_config(method, estimator)["dense"]

    if dense_config == "default" or not _auto_wrap_is_configured(estimator):
        # Nothing to do: either wrapping is disabled or the estimator does not
        # expose `get_feature_names_out`.
        return data_to_wrap

    if issparse(data_to_wrap):
        # Dataframe containers cannot represent scipy sparse matrices.
        raise ValueError(
            "The transformer outputs a scipy sparse matrix. "
            "Try to set the transformer output to a dense array or disable "
            f"{dense_config.capitalize()} output with set_output(transform='default')."
        )

    adapter = ADAPTERS_MANAGER.adapters[dense_config]
    # `columns` is passed as a callable so names are resolved lazily after fit.
    return adapter.create_container(
        data_to_wrap,
        original_input,
        columns=estimator.get_feature_names_out,
    )
|
||||
|
||||
|
||||
def _wrap_method_output(f, method):
    """Wrapper used by `_SetOutputMixin` to automatically wrap methods."""

    @wraps(f)
    def wrapped(self, X, *args, **kwargs):
        raw_output = f(self, X, *args, **kwargs)

        if not isinstance(raw_output, tuple):
            return _wrap_data_with_container(method, raw_output, X, self)

        # Only wrap the first output for cross decomposition.
        wrapped_tuple = (
            _wrap_data_with_container(method, raw_output[0], X, self),
        ) + raw_output[1:]

        # Support for namedtuples: `_make` is a documented API for namedtuples:
        # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make
        make = getattr(type(raw_output), "_make", None)
        if make is not None:
            return make(wrapped_tuple)
        return wrapped_tuple

    return wrapped
|
||||
|
||||
|
||||
def _auto_wrap_is_configured(estimator):
|
||||
"""Return True if estimator is configured for auto-wrapping the transform method.
|
||||
|
||||
`_SetOutputMixin` sets `_sklearn_auto_wrap_output_keys` to `set()` if auto wrapping
|
||||
is manually disabled.
|
||||
"""
|
||||
auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set())
|
||||
return (
|
||||
hasattr(estimator, "get_feature_names_out")
|
||||
and "transform" in auto_wrap_output_keys
|
||||
)
|
||||
|
||||
|
||||
class _SetOutputMixin:
    """Mixin that dynamically wraps methods to return container based on config.

    Currently `_SetOutputMixin` wraps `transform` and `fit_transform` and configures
    it based on `set_output` of the global configuration.

    `set_output` is only defined if `get_feature_names_out` is defined and
    `auto_wrap_output_keys` is the default value.
    """

    def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
        super().__init_subclass__(**kwargs)

        # Dynamically wraps `transform` and `fit_transform` and configure it's
        # output based on `set_output`. Only `None` or a tuple is accepted.
        if auto_wrap_output_keys is not None and not isinstance(
            auto_wrap_output_keys, tuple
        ):
            raise ValueError("auto_wrap_output_keys must be None or a tuple of keys.")

        cls._sklearn_auto_wrap_output_keys = set()
        if auto_wrap_output_keys is None:
            # Auto wrapping explicitly disabled for this subclass.
            return

        # Both methods are controlled by the single "transform" config key.
        for method_name, config_key in (
            ("transform", "transform"),
            ("fit_transform", "transform"),
        ):
            if config_key not in auto_wrap_output_keys or not hasattr(
                cls, method_name
            ):
                continue
            cls._sklearn_auto_wrap_output_keys.add(config_key)

            # Only wrap methods defined by cls itself, so methods inherited
            # (and already wrapped) from a base class are not wrapped twice.
            if method_name in cls.__dict__:
                setattr(
                    cls,
                    method_name,
                    _wrap_method_output(getattr(cls, method_name), config_key),
                )

    @available_if(_auto_wrap_is_configured)
    def set_output(self, *, transform=None):
        """Set output container.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        Parameters
        ----------
        transform : {"default", "pandas", "polars"}, default=None
            Configure output of `transform` and `fit_transform`.

            - `"default"`: Default output format of a transformer
            - `"pandas"`: DataFrame output
            - `"polars"`: Polars output
            - `None`: Transform configuration is unchanged

            .. versionadded:: 1.4
                `"polars"` option was added.

        Returns
        -------
        self : estimator instance
            Estimator instance.
        """
        if transform is None:
            return self

        # Lazily create the per-instance config dict on first use.
        config = getattr(self, "_sklearn_output_config", None)
        if config is None:
            config = self._sklearn_output_config = {}
        config["transform"] = transform
        return self
|
||||
|
||||
|
||||
def _safe_set_output(estimator, *, transform=None):
|
||||
"""Safely call estimator.set_output and error if it not available.
|
||||
|
||||
This is used by meta-estimators to set the output for child estimators.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : estimator instance
|
||||
Estimator instance.
|
||||
|
||||
transform : {"default", "pandas", "polars"}, default=None
|
||||
Configure output of the following estimator's methods:
|
||||
|
||||
- `"transform"`
|
||||
- `"fit_transform"`
|
||||
|
||||
If `None`, this operation is a no-op.
|
||||
|
||||
Returns
|
||||
-------
|
||||
estimator : estimator instance
|
||||
Estimator instance.
|
||||
"""
|
||||
set_output_for_transform = hasattr(estimator, "transform") or (
|
||||
hasattr(estimator, "fit_transform") and transform is not None
|
||||
)
|
||||
if not set_output_for_transform:
|
||||
# If estimator can not transform, then `set_output` does not need to be
|
||||
# called.
|
||||
return
|
||||
|
||||
if not hasattr(estimator, "set_output"):
|
||||
raise ValueError(
|
||||
f"Unable to configure output for {estimator} because `set_output` "
|
||||
"is not available."
|
||||
)
|
||||
return estimator.set_output(transform=transform)
|
||||
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Utility methods to print system info for debugging
|
||||
|
||||
adapted from :func:`pandas.show_versions`
|
||||
"""
|
||||
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import platform
|
||||
import sys
|
||||
|
||||
from threadpoolctl import threadpool_info
|
||||
|
||||
from .. import __version__
|
||||
from ._openmp_helpers import _openmp_parallelism_enabled
|
||||
|
||||
|
||||
def _get_sys_info():
|
||||
"""System information
|
||||
|
||||
Returns
|
||||
-------
|
||||
sys_info : dict
|
||||
system and Python version information
|
||||
|
||||
"""
|
||||
python = sys.version.replace("\n", " ")
|
||||
|
||||
blob = [
|
||||
("python", python),
|
||||
("executable", sys.executable),
|
||||
("machine", platform.platform()),
|
||||
]
|
||||
|
||||
return dict(blob)
|
||||
|
||||
|
||||
def _get_deps_info():
    """Overview of the installed version of main dependencies

    This function does not import the modules to collect the version numbers
    but instead relies on standard Python package metadata.

    Returns
    -------
    deps_info: dict
        version information on relevant Python libraries

    """
    from importlib.metadata import PackageNotFoundError, version

    relevant_packages = (
        "pip",
        "setuptools",
        "numpy",
        "scipy",
        "Cython",
        "pandas",
        "matplotlib",
        "joblib",
        "threadpoolctl",
    )

    # scikit-learn's own version comes from the package itself, not metadata.
    deps_info = {"sklearn": __version__}
    for modname in relevant_packages:
        try:
            deps_info[modname] = version(modname)
        except PackageNotFoundError:
            # Not installed: report None rather than failing.
            deps_info[modname] = None
    return deps_info
|
||||
|
||||
|
||||
def show_versions():
    """Print useful debugging information.

    .. versionadded:: 0.20

    Examples
    --------
    >>> from sklearn import show_versions
    >>> show_versions() # doctest: +SKIP
    """
    sys_info = _get_sys_info()
    deps_info = _get_deps_info()

    print("\nSystem:")
    for key, value in sys_info.items():
        print(f"{key:>10}: {value}")

    print("\nPython dependencies:")
    for key, value in deps_info.items():
        print(f"{key:>13}: {value}")

    print(f"\nBuilt with OpenMP: {_openmp_parallelism_enabled()}")

    # Show threadpoolctl results, one blank-line-separated section per pool.
    threadpool_results = threadpool_info()
    if threadpool_results:
        print()
        print("threadpoolctl info:")

        for i, result in enumerate(threadpool_results):
            for key, val in result.items():
                print(f"{key:>15}: {val}")
            if i != len(threadpool_results) - 1:
                print()
|
||||
Binary file not shown.
@@ -0,0 +1,9 @@
|
||||
from ._typedefs cimport intp_t
|
||||
|
||||
from cython cimport floating
|
||||
|
||||
# Sort `dist` ascendingly in-place, applying the same permutation to `idx`.
# Implementation lives in `_sorting.pyx`.
cdef int simultaneous_sort(
    floating *dist,
    intp_t *idx,
    intp_t size,
) noexcept nogil
|
||||
93
venv/lib/python3.12/site-packages/sklearn/utils/_sorting.pyx
Normal file
93
venv/lib/python3.12/site-packages/sklearn/utils/_sorting.pyx
Normal file
@@ -0,0 +1,93 @@
|
||||
from cython cimport floating
|
||||
|
||||
cdef inline void dual_swap(
    floating* darr,
    intp_t *iarr,
    intp_t a,
    intp_t b,
) noexcept nogil:
    """Swap the values at index a and b of both darr and iarr"""
    # Swap the floating values.
    cdef floating dtmp = darr[a]
    darr[a] = darr[b]
    darr[b] = dtmp

    # Swap the paired indices with the same permutation.
    cdef intp_t itmp = iarr[a]
    iarr[a] = iarr[b]
    iarr[b] = itmp
|
||||
|
||||
|
||||
cdef int simultaneous_sort(
    floating* values,
    intp_t* indices,
    intp_t size,
) noexcept nogil:
    """
    Perform a recursive quicksort on the values array as to sort them ascendingly.
    This simultaneously performs the swaps on both the values and the indices arrays.

    The numpy equivalent is:

        def simultaneous_sort(dist, idx):
            i = np.argsort(dist)
            return dist[i], idx[i]

    Always returns 0 (no error reporting).

    Notes
    -----
    Arrays are manipulated via a pointer to their first element and their size
    as to ease the processing of dynamically allocated buffers.
    """
    # TODO: In order to support discrete distance metrics, we need to have a
    # simultaneous sort which breaks ties on indices when distances are identical.
    # The best might be using a std::stable_sort and a Comparator which might need
    # an Array of Structures (AoS) instead of the Structure of Arrays (SoA)
    # currently used.
    cdef:
        intp_t pivot_idx, i, store_idx
        floating pivot_val

    # in the small-array case, do things efficiently
    if size <= 1:
        pass
    elif size == 2:
        if values[0] > values[1]:
            dual_swap(values, indices, 0, 1)
    elif size == 3:
        # Hand-unrolled 3-element sorting network.
        if values[0] > values[1]:
            dual_swap(values, indices, 0, 1)
        if values[1] > values[2]:
            dual_swap(values, indices, 1, 2)
            if values[0] > values[1]:
                dual_swap(values, indices, 0, 1)
    else:
        # Determine the pivot using the median-of-three rule.
        # The smallest of the three is moved to the beginning of the array,
        # the middle (the pivot value) is moved to the end, and the largest
        # is moved to the pivot index.
        pivot_idx = size // 2
        if values[0] > values[size - 1]:
            dual_swap(values, indices, 0, size - 1)
        if values[size - 1] > values[pivot_idx]:
            dual_swap(values, indices, size - 1, pivot_idx)
            if values[0] > values[size - 1]:
                dual_swap(values, indices, 0, size - 1)
        pivot_val = values[size - 1]

        # Partition indices about pivot.  At the end of this operation,
        # pivot_idx will contain the pivot value, everything to the left
        # will be smaller, and everything to the right will be larger.
        store_idx = 0
        for i in range(size - 1):
            if values[i] < pivot_val:
                dual_swap(values, indices, i, store_idx)
                store_idx += 1
        dual_swap(values, indices, store_idx, size - 1)
        pivot_idx = store_idx

        # Recursively sort each side of the pivot.
        # (Sub-arrays of size <= 1 are already sorted and skipped.)
        if pivot_idx > 1:
            simultaneous_sort(values, indices, pivot_idx)
        if pivot_idx + 2 < size:
            simultaneous_sort(values + pivot_idx + 1,
                              indices + pivot_idx + 1,
                              size - pivot_idx - 1)
    return 0
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user