add read me

This commit is contained in:
2026-01-09 10:28:44 +11:00
commit edaf914b73
13417 changed files with 2952119 additions and 0 deletions

View File

@@ -0,0 +1,84 @@
"""Various utilities to help with development."""
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from ..exceptions import DataConversionWarning
from . import metadata_routing
from ._bunch import Bunch
from ._chunking import gen_batches, gen_even_slices
# Make _safe_indexing importable from here for backward compat as this particular
# helper is considered semi-private and typically very useful for third-party
# libraries that want to comply with scikit-learn's estimator API. In particular,
# _safe_indexing was included in our public API documentation despite the leading
# `_` in its name.
from ._indexing import (
_safe_indexing, # noqa: F401
resample,
shuffle,
)
from ._mask import safe_mask
from ._repr_html.base import _HTMLDocumentationLinkMixin # noqa: F401
from ._repr_html.estimator import estimator_html_repr
from ._tags import (
ClassifierTags,
InputTags,
RegressorTags,
Tags,
TargetTags,
TransformerTags,
get_tags,
)
from .class_weight import compute_class_weight, compute_sample_weight
from .deprecation import deprecated
from .discovery import all_estimators
from .extmath import safe_sqr
from .murmurhash import murmurhash3_32
from .validation import (
as_float_array,
assert_all_finite,
check_array,
check_consistent_length,
check_random_state,
check_scalar,
check_symmetric,
check_X_y,
column_or_1d,
indexable,
)
# Public API of `sklearn.utils`, kept in ASCII sort order (capitalized names
# first). Names imported above but absent from this list (e.g.
# `_safe_indexing`, `safe_mask` helpers' private siblings) are semi-private.
__all__ = [
    "Bunch",
    "ClassifierTags",
    "DataConversionWarning",
    "InputTags",
    "RegressorTags",
    "Tags",
    "TargetTags",
    "TransformerTags",
    "all_estimators",
    "as_float_array",
    "assert_all_finite",
    "check_X_y",
    "check_array",
    "check_consistent_length",
    "check_random_state",
    "check_scalar",
    "check_symmetric",
    "column_or_1d",
    "compute_class_weight",
    "compute_sample_weight",
    "deprecated",
    "estimator_html_repr",
    "gen_batches",
    "gen_even_slices",
    "get_tags",
    "indexable",
    "metadata_routing",
    "murmurhash3_32",
    "resample",
    "safe_mask",
    "safe_sqr",
    "shuffle",
]

View File

@@ -0,0 +1,33 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from .validation import check_random_state
def _init_arpack_v0(size, random_state):
    """Initialize the starting vector for iteration in ARPACK functions.

    The vector is filled with values sampled uniformly from [-1, 1]. This
    matches ARPACK's own initialization; other schemes can lead to
    convergence issues.

    Parameters
    ----------
    size : int
        The size of the eigenvalue vector to be initialized.

    random_state : int, RandomState instance or None, default=None
        The seed of the pseudo random number generator used to generate a
        uniform distribution. If int, random_state is the seed used by the
        random number generator; If RandomState instance, random_state is the
        random number generator; If None, the random number generator is the
        RandomState instance used by `np.random`.

    Returns
    -------
    v0 : ndarray of shape (size,)
        The initialized vector.
    """
    rng = check_random_state(random_state)
    return rng.uniform(-1, 1, size)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,96 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from functools import update_wrapper, wraps
from types import MethodType
class _AvailableIfDescriptor:
"""Implements a conditional property using the descriptor protocol.
Using this class to create a decorator will raise an ``AttributeError``
if check(self) returns a falsey value. Note that if check raises an error
this will also result in hasattr returning false.
See https://docs.python.org/3/howto/descriptor.html for an explanation of
descriptors.
"""
def __init__(self, fn, check, attribute_name):
self.fn = fn
self.check = check
self.attribute_name = attribute_name
# update the docstring of the descriptor
update_wrapper(self, fn)
def _check(self, obj, owner):
attr_err_msg = (
f"This {owner.__name__!r} has no attribute {self.attribute_name!r}"
)
try:
check_result = self.check(obj)
except Exception as e:
raise AttributeError(attr_err_msg) from e
if not check_result:
raise AttributeError(attr_err_msg)
def __get__(self, obj, owner=None):
if obj is not None:
# delegate only on instances, not the classes.
# this is to allow access to the docstrings.
self._check(obj, owner=owner)
out = MethodType(self.fn, obj)
else:
# This makes it possible to use the decorated method as an unbound method,
# for instance when monkeypatching.
@wraps(self.fn)
def out(*args, **kwargs):
self._check(args[0], owner=owner)
return self.fn(*args, **kwargs)
return out
def available_if(check):
    """An attribute that is available only if check returns a truthy value.

    Parameters
    ----------
    check : callable
        Called with the object holding the decorated method. A truthy return
        value makes the attribute available; returning False or raising
        AttributeError makes it unavailable.

    Returns
    -------
    callable
        Decorator that makes the wrapped method available if `check` returns
        a truthy value, otherwise the method is unavailable.

    Examples
    --------
    >>> from sklearn.utils.metaestimators import available_if
    >>> class HelloIfEven:
    ...     def __init__(self, x):
    ...         self.x = x
    ...
    ...     def _x_is_even(self):
    ...         return self.x % 2 == 0
    ...
    ...     @available_if(_x_is_even)
    ...     def say_hello(self):
    ...         print("Hello")
    ...
    >>> obj = HelloIfEven(1)
    >>> hasattr(obj, "say_hello")
    False
    >>> obj.x = 2
    >>> hasattr(obj, "say_hello")
    True
    >>> obj.say_hello()
    Hello
    """

    def decorator(fn):
        return _AvailableIfDescriptor(fn, check, attribute_name=fn.__name__)

    return decorator

View File

@@ -0,0 +1,70 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import warnings
class Bunch(dict):
    """Container object exposing keys as attributes.

    A ``Bunch`` is a dictionary whose values can also be read and written
    as attributes: ``bunch["value_key"]`` and ``bunch.value_key`` are
    interchangeable. It is sometimes used as the return type of functions
    and methods.

    Examples
    --------
    >>> from sklearn.utils import Bunch
    >>> b = Bunch(a=1, b=2)
    >>> b['b']
    2
    >>> b.b
    2
    >>> b.a = 3
    >>> b['a']
    3
    >>> b.c = 6
    >>> b['c']
    6
    """

    def __init__(self, **kwargs):
        super().__init__(kwargs)
        # Write straight into the instance __dict__ to bypass __setattr__
        # (which would store the registry as a regular dict entry).
        self.__dict__["_deprecated_key_to_warnings"] = {}

    def __getitem__(self, key):
        # Emit the registered FutureWarning when a deprecated key is read.
        deprecated = self.__dict__.get("_deprecated_key_to_warnings", {})
        if key in deprecated:
            warnings.warn(deprecated[key], FutureWarning)
        return super().__getitem__(key)

    def _set_deprecated(self, value, *, new_key, deprecated_key, warning_message):
        """Store `value` under both keys and register the deprecation warning."""
        self.__dict__["_deprecated_key_to_warnings"][deprecated_key] = warning_message
        self[new_key] = value
        self[deprecated_key] = value

    def __setattr__(self, key, value):
        # Attribute writes are dictionary writes.
        self[key] = value

    def __dir__(self):
        return self.keys()

    def __getattr__(self, key):
        # Attribute reads fall back to dictionary lookup.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setstate__(self, state):
        # Bunch pickles produced by scikit-learn 0.16.* carried a non-empty
        # __dict__, which made attribute reads and writes disagree when
        # unpickled in 0.17 (reads hit __dict__, writes went through
        # __setattr__ into the dict). See
        # https://github.com/scikit-learn/scikit-learn/issues/6196.
        # Ignoring the pickled __dict__ entirely avoids that mismatch.
        pass

View File

@@ -0,0 +1,178 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from itertools import islice
from numbers import Integral
import numpy as np
from .._config import get_config
from ._param_validation import Interval, validate_params
def chunk_generator(gen, chunksize):
    """Yield successive lists of at most ``chunksize`` items from ``gen``.

    The final chunk may be shorter than ``chunksize`` when the iterator is
    exhausted mid-chunk.
    """
    # iter() with a sentinel: keep slicing until islice comes back empty.
    for chunk in iter(lambda: list(islice(gen, chunksize)), []):
        yield chunk
@validate_params(
    {
        "n": [Interval(Integral, 1, None, closed="left")],
        "batch_size": [Interval(Integral, 1, None, closed="left")],
        "min_batch_size": [Interval(Integral, 0, None, closed="left")],
    },
    prefer_skip_nested_validation=True,
)
def gen_batches(n, batch_size, *, min_batch_size=0):
    """Generator to create slices containing `batch_size` elements from 0 to `n`.

    The last slice may contain less than `batch_size` elements, when
    `batch_size` does not divide `n`.

    Parameters
    ----------
    n : int
        Size of the sequence.

    batch_size : int
        Number of elements in each batch.

    min_batch_size : int, default=0
        Minimum number of elements in each batch.

    Yields
    ------
    slice of `batch_size` elements

    See Also
    --------
    gen_even_slices: Generator to create n_packs slices going up to n.

    Examples
    --------
    >>> from sklearn.utils import gen_batches
    >>> list(gen_batches(7, 3))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(6, 3))
    [slice(0, 3, None), slice(3, 6, None)]
    >>> list(gen_batches(2, 3))
    [slice(0, 2, None)]
    >>> list(gen_batches(7, 3, min_batch_size=0))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(7, 3, min_batch_size=2))
    [slice(0, 3, None), slice(3, 7, None)]
    """
    start = 0
    for _ in range(n // batch_size):
        stop = start + batch_size
        # Stop emitting full batches as soon as finishing this one would
        # leave fewer than `min_batch_size` elements; the remainder is
        # folded into the final (tail) slice below.
        if stop + min_batch_size > n:
            break
        yield slice(start, stop)
        start = stop
    if start < n:
        yield slice(start, n)
@validate_params(
    {
        "n": [Interval(Integral, 1, None, closed="left")],
        "n_packs": [Interval(Integral, 1, None, closed="left")],
        "n_samples": [Interval(Integral, 1, None, closed="left"), None],
    },
    prefer_skip_nested_validation=True,
)
def gen_even_slices(n, n_packs, *, n_samples=None):
    """Generator to create `n_packs` evenly spaced slices going up to `n`.

    If `n_packs` does not divide `n`, except for the first `n % n_packs`
    slices, remaining slices may contain fewer elements.

    Parameters
    ----------
    n : int
        Size of the sequence.

    n_packs : int
        Number of slices to generate.

    n_samples : int, default=None
        Number of samples. Pass `n_samples` when the slices are to be used for
        sparse matrix indexing; slicing off-the-end raises an exception, while
        it works for NumPy arrays.

    Yields
    ------
    `slice` representing a set of indices from 0 to n.

    See Also
    --------
    gen_batches: Generator to create slices containing batch_size elements
        from 0 to n.

    Examples
    --------
    >>> from sklearn.utils import gen_even_slices
    >>> list(gen_even_slices(10, 1))
    [slice(0, 10, None)]
    >>> list(gen_even_slices(10, 10))
    [slice(0, 1, None), slice(1, 2, None), ..., slice(9, 10, None)]
    >>> list(gen_even_slices(10, 5))
    [slice(0, 2, None), slice(2, 4, None), ..., slice(8, 10, None)]
    >>> list(gen_even_slices(10, 3))
    [slice(0, 4, None), slice(4, 7, None), slice(7, 10, None)]
    """
    start = 0
    base, remainder = divmod(n, n_packs)
    for pack_num in range(n_packs):
        # The first `remainder` packs get one extra element each.
        length = base + (1 if pack_num < remainder else 0)
        if length == 0:
            # More packs than elements: skip empty slices entirely.
            continue
        stop = start + length
        if n_samples is not None:
            # Clip so the slice never runs past the data (sparse indexing
            # raises on out-of-range slices, unlike NumPy arrays).
            stop = min(n_samples, stop)
        yield slice(start, stop, None)
        start = stop
def get_chunk_n_rows(row_bytes, *, max_n_rows=None, working_memory=None):
    """Calculate how many rows can be processed within `working_memory`.

    Parameters
    ----------
    row_bytes : int
        The expected number of bytes of memory that will be consumed
        during the processing of each row.

    max_n_rows : int, default=None
        The maximum return value.

    working_memory : int or float, default=None
        The number of rows to fit inside this number of MiB will be
        returned. When None (default), the value of
        ``sklearn.get_config()['working_memory']`` is used.

    Returns
    -------
    int
        The number of rows which can be processed within `working_memory`.
        Always at least 1, even when a single row exceeds the budget.

    Warns
    -----
    Issues a UserWarning if `row_bytes` exceeds `working_memory` MiB.
    """
    if working_memory is None:
        working_memory = get_config()["working_memory"]

    # How many whole rows fit in `working_memory` MiB.
    n_rows = int(working_memory * (2**20) // row_bytes)
    if max_n_rows is not None:
        n_rows = min(n_rows, max_n_rows)
    if n_rows >= 1:
        return n_rows

    # A single row already blows the budget: warn but still make progress.
    warnings.warn(
        "Could not adhere to working_memory config. "
        "Currently %.0fMiB, %.0fMiB required."
        % (working_memory, np.ceil(row_bytes * 2**-20))
    )
    return 1

View File

@@ -0,0 +1,41 @@
from cython cimport floating
# Memory-layout flag passed to the wrappers below so they can map row-major
# (C) arrays onto the column-major Fortran BLAS routines.
cpdef enum BLAS_Order:
    RowMajor  # C contiguous
    ColMajor  # Fortran contiguous


# Transpose flag; values are the ASCII codes expected by Fortran BLAS.
cpdef enum BLAS_Trans:
    NoTrans = 110  # correspond to 'n'
    Trans = 116    # correspond to 't'


# BLAS Level 1 ################################################################
# Vector-vector operations; see the .pyx implementation for the exact
# formula each wrapper computes.
cdef floating _dot(int, const floating*, int, const floating*, int) noexcept nogil

cdef floating _asum(int, const floating*, int) noexcept nogil

cdef void _axpy(int, floating, const floating*, int, floating*, int) noexcept nogil

cdef floating _nrm2(int, const floating*, int) noexcept nogil

# NOTE(review): the second vector is the destination of the copy yet is
# declared `const` — confirm against the .pyx implementation.
cdef void _copy(int, const floating*, int, const floating*, int) noexcept nogil

cdef void _scal(int, floating, const floating*, int) noexcept nogil

cdef void _rotg(floating*, floating*, floating*, floating*) noexcept nogil

cdef void _rot(int, floating*, int, floating*, int, floating, floating) noexcept nogil

# BLAS Level 2 ################################################################
# Matrix-vector operations.
cdef void _gemv(BLAS_Order, BLAS_Trans, int, int, floating, const floating*, int,
                const floating*, int, floating, floating*, int) noexcept nogil

cdef void _ger(BLAS_Order, int, int, floating, const floating*, int, const floating*,
               int, floating*, int) noexcept nogil

# BLAS Level 3 ################################################################
# Matrix-matrix operations.
cdef void _gemm(BLAS_Order, BLAS_Trans, BLAS_Trans, int, int, int, floating,
                const floating*, int, const floating*, int, floating, floating*,
                int) noexcept nogil

View File

@@ -0,0 +1,239 @@
from cython cimport floating
from scipy.linalg.cython_blas cimport sdot, ddot
from scipy.linalg.cython_blas cimport sasum, dasum
from scipy.linalg.cython_blas cimport saxpy, daxpy
from scipy.linalg.cython_blas cimport snrm2, dnrm2
from scipy.linalg.cython_blas cimport scopy, dcopy
from scipy.linalg.cython_blas cimport sscal, dscal
from scipy.linalg.cython_blas cimport srotg, drotg
from scipy.linalg.cython_blas cimport srot, drot
from scipy.linalg.cython_blas cimport sgemv, dgemv
from scipy.linalg.cython_blas cimport sger, dger
from scipy.linalg.cython_blas cimport sgemm, dgemm
################
# BLAS Level 1 #
################

# Each `_xxx` function dispatches on the fused `floating` type to the
# single- or double-precision SciPy BLAS routine. The `<float *>` /
# `<double *>` casts strip the const qualifier because SciPy's cython_blas
# signatures are not const-qualified (see the TODO in `_gemm` below).
# The `_xxx_memview` variants are Python-callable wrappers over contiguous
# memoryviews.

cdef floating _dot(int n, const floating *x, int incx,
                   const floating *y, int incy) noexcept nogil:
    """x.T.y"""
    if floating is float:
        return sdot(&n, <float *> x, &incx, <float *> y, &incy)
    else:
        return ddot(&n, <double *> x, &incx, <double *> y, &incy)


cpdef _dot_memview(const floating[::1] x, const floating[::1] y):
    return _dot(x.shape[0], &x[0], 1, &y[0], 1)


cdef floating _asum(int n, const floating *x, int incx) noexcept nogil:
    """sum(|x_i|)"""
    if floating is float:
        return sasum(&n, <float *> x, &incx)
    else:
        return dasum(&n, <double *> x, &incx)


cpdef _asum_memview(const floating[::1] x):
    return _asum(x.shape[0], &x[0], 1)


cdef void _axpy(int n, floating alpha, const floating *x, int incx,
                floating *y, int incy) noexcept nogil:
    """y := alpha * x + y"""
    if floating is float:
        saxpy(&n, &alpha, <float *> x, &incx, y, &incy)
    else:
        daxpy(&n, &alpha, <double *> x, &incx, y, &incy)


cpdef _axpy_memview(floating alpha, const floating[::1] x, floating[::1] y):
    _axpy(x.shape[0], alpha, &x[0], 1, &y[0], 1)


cdef floating _nrm2(int n, const floating *x, int incx) noexcept nogil:
    """sqrt(sum((x_i)^2))"""
    if floating is float:
        return snrm2(&n, <float *> x, &incx)
    else:
        return dnrm2(&n, <double *> x, &incx)


cpdef _nrm2_memview(const floating[::1] x):
    return _nrm2(x.shape[0], &x[0], 1)


# NOTE(review): `y` is the copy destination yet declared `const`; the cast
# below strips the qualifier before handing it to BLAS — confirm intended.
cdef void _copy(int n, const floating *x, int incx, const floating *y, int incy) noexcept nogil:
    """y := x"""
    if floating is float:
        scopy(&n, <float *> x, &incx, <float *> y, &incy)
    else:
        dcopy(&n, <double *> x, &incx, <double *> y, &incy)


cpdef _copy_memview(const floating[::1] x, const floating[::1] y):
    _copy(x.shape[0], &x[0], 1, &y[0], 1)


cdef void _scal(int n, floating alpha, const floating *x, int incx) noexcept nogil:
    """x := alpha * x"""
    if floating is float:
        sscal(&n, &alpha, <float *> x, &incx)
    else:
        dscal(&n, &alpha, <double *> x, &incx)


cpdef _scal_memview(floating alpha, const floating[::1] x):
    _scal(x.shape[0], alpha, &x[0], 1)


cdef void _rotg(floating *a, floating *b, floating *c, floating *s) noexcept nogil:
    """Generate plane rotation"""
    if floating is float:
        srotg(a, b, c, s)
    else:
        drotg(a, b, c, s)


cpdef _rotg_memview(floating a, floating b, floating c, floating s):
    # Scalars are passed by value, so return the updated quadruple.
    _rotg(&a, &b, &c, &s)
    return a, b, c, s


cdef void _rot(int n, floating *x, int incx, floating *y, int incy,
               floating c, floating s) noexcept nogil:
    """Apply plane rotation"""
    if floating is float:
        srot(&n, x, &incx, y, &incy, &c, &s)
    else:
        drot(&n, x, &incx, y, &incy, &c, &s)


cpdef _rot_memview(floating[::1] x, floating[::1] y, floating c, floating s):
    _rot(x.shape[0], &x[0], 1, &y[0], 1, c, s)
################
# BLAS Level 2 #
################

cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha,
                const floating *A, int lda, const floating *x, int incx,
                floating beta, floating *y, int incy) noexcept nogil:
    """y := alpha * op(A).x + beta * y"""
    cdef char ta_ = ta
    if order == BLAS_Order.RowMajor:
        # Fortran BLAS is column-major: a row-major A is its transpose in
        # Fortran layout, so flip the transpose flag and swap m/n.
        ta_ = BLAS_Trans.NoTrans if ta == BLAS_Trans.Trans else BLAS_Trans.Trans
        if floating is float:
            sgemv(&ta_, &n, &m, &alpha, <float *> A, &lda, <float *> x,
                  &incx, &beta, y, &incy)
        else:
            dgemv(&ta_, &n, &m, &alpha, <double *> A, &lda, <double *> x,
                  &incx, &beta, y, &incy)
    else:
        if floating is float:
            sgemv(&ta_, &m, &n, &alpha, <float *> A, &lda, <float *> x,
                  &incx, &beta, y, &incy)
        else:
            dgemv(&ta_, &m, &n, &alpha, <double *> A, &lda, <double *> x,
                  &incx, &beta, y, &incy)


cpdef _gemv_memview(BLAS_Trans ta, floating alpha, const floating[:, :] A,
                    const floating[::1] x, floating beta, floating[::1] y):
    cdef:
        int m = A.shape[0]
        int n = A.shape[1]
        # Infer the layout of A from its first stride.
        BLAS_Order order = (
            BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
        )
        int lda = m if order == BLAS_Order.ColMajor else n

    _gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1)


cdef void _ger(BLAS_Order order, int m, int n, floating alpha,
               const floating *x, int incx, const floating *y,
               int incy, floating *A, int lda) noexcept nogil:
    """A := alpha * x.y.T + A"""
    if order == BLAS_Order.RowMajor:
        # Row-major A equals A.T in Fortran layout: perform the update
        # alpha * y.x.T on the transposed view by swapping x/y and m/n.
        if floating is float:
            sger(&n, &m, &alpha, <float *> y, &incy, <float *> x, &incx, A, &lda)
        else:
            dger(&n, &m, &alpha, <double *> y, &incy, <double *> x, &incx, A, &lda)
    else:
        if floating is float:
            sger(&m, &n, &alpha, <float *> x, &incx, <float *> y, &incy, A, &lda)
        else:
            dger(&m, &n, &alpha, <double *> x, &incx, <double *> y, &incy, A, &lda)


cpdef _ger_memview(floating alpha, const floating[::1] x,
                   const floating[::1] y, floating[:, :] A):
    cdef:
        int m = A.shape[0]
        int n = A.shape[1]
        # Infer the layout of A from its first stride.
        BLAS_Order order = (
            BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
        )
        int lda = m if order == BLAS_Order.ColMajor else n

    _ger(order, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)
################
# BLAS Level 3 #
################

cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
                int k, floating alpha, const floating *A, int lda, const floating *B,
                int ldb, floating beta, floating *C, int ldc) noexcept nogil:
    """C := alpha * op(A).op(B) + beta * C"""
    # TODO: Remove the pointer casts below once SciPy uses const-qualification.
    # See: https://github.com/scipy/scipy/issues/14262
    cdef:
        char ta_ = ta
        char tb_ = tb
    if order == BLAS_Order.RowMajor:
        # Column-major BLAS computing (C.T = op(B).T . op(A).T) on the
        # transposed views yields the row-major C: swap A/B, m/n, ta/tb.
        if floating is float:
            sgemm(&tb_, &ta_, &n, &m, &k, &alpha, <float*>B,
                  &ldb, <float*>A, &lda, &beta, C, &ldc)
        else:
            dgemm(&tb_, &ta_, &n, &m, &k, &alpha, <double*>B,
                  &ldb, <double*>A, &lda, &beta, C, &ldc)
    else:
        if floating is float:
            sgemm(&ta_, &tb_, &m, &n, &k, &alpha, <float*>A,
                  &lda, <float*>B, &ldb, &beta, C, &ldc)
        else:
            dgemm(&ta_, &tb_, &m, &n, &k, &alpha, <double*>A,
                  &lda, <double*>B, &ldb, &beta, C, &ldc)


cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha,
                    const floating[:, :] A, const floating[:, :] B, floating beta,
                    floating[:, :] C):
    cdef:
        # Dimensions of op(A) (m x k) and op(B) (k x n).
        int m = A.shape[0] if ta == BLAS_Trans.NoTrans else A.shape[1]
        int n = B.shape[1] if tb == BLAS_Trans.NoTrans else B.shape[0]
        int k = A.shape[1] if ta == BLAS_Trans.NoTrans else A.shape[0]
        int lda, ldb, ldc
        # Infer the layout of A from its first stride.
        BLAS_Order order = (
            BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
        )

    # Leading dimensions depend on both the layout and the transpose flags.
    if order == BLAS_Order.RowMajor:
        lda = k if ta == BLAS_Trans.NoTrans else m
        ldb = n if tb == BLAS_Trans.NoTrans else k
        ldc = n
    else:
        lda = m if ta == BLAS_Trans.NoTrans else k
        ldb = k if tb == BLAS_Trans.NoTrans else n
        ldc = m

    _gemm(order, ta, tb, m, n, k, alpha, &A[0, 0],
          lda, &B[0, 0], ldb, beta, &C[0, 0], ldc)

View File

@@ -0,0 +1,376 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from collections import Counter
from contextlib import suppress
from typing import NamedTuple
import numpy as np
from ._array_api import (
_isin,
_searchsorted,
device,
get_namespace,
xpx,
)
from ._missing import is_scalar_nan
def _unique(values, *, return_inverse=False, return_counts=False):
    """Helper function to find unique values with support for python objects.

    Uses pure python method for object dtype, and numpy method for
    all other dtypes.

    Parameters
    ----------
    values : ndarray
        Values to check for unknowns.

    return_inverse : bool, default=False
        If True, also return the indices of the unique values.

    return_counts : bool, default=False
        If True, also return the number of times each unique item appears in
        values.

    Returns
    -------
    unique : ndarray
        The sorted unique values.

    unique_inverse : ndarray
        The indices to reconstruct the original array from the unique array.
        Only provided if `return_inverse` is True.

    unique_counts : ndarray
        The number of times each of the unique values comes up in the original
        array. Only provided if `return_counts` is True.
    """
    # Pick the implementation based on the dtype, then forward the flags.
    handler = _unique_python if values.dtype == object else _unique_np
    return handler(values, return_inverse=return_inverse, return_counts=return_counts)
def _unique_np(values, return_inverse=False, return_counts=False):
    """Helper function to find unique values for numpy arrays that correctly
    accounts for nans. See `_unique` documentation for details."""
    xp, _ = get_namespace(values)

    inverse, counts = None, None

    # Use the single array-API call that returns everything requested.
    if return_inverse and return_counts:
        uniques, _, inverse, counts = xp.unique_all(values)
    elif return_inverse:
        uniques, inverse = xp.unique_inverse(values)
    elif return_counts:
        uniques, counts = xp.unique_counts(values)
    else:
        uniques = xp.unique_values(values)

    # np.unique will have duplicate missing values at the end of `uniques`
    # here we clip the nans and remove it from uniques
    if uniques.size and is_scalar_nan(uniques[-1]):
        # Index of the first nan (nans sort to the end).
        nan_idx = _searchsorted(uniques, xp.nan, xp=xp)
        uniques = uniques[: nan_idx + 1]
        if return_inverse:
            # Remap indices of the dropped duplicate nans onto the kept one.
            inverse[inverse > nan_idx] = nan_idx

        if return_counts:
            # Fold all duplicate-nan counts into the single kept nan.
            counts[nan_idx] = xp.sum(counts[nan_idx:])
            counts = counts[: nan_idx + 1]

    ret = (uniques,)

    if return_inverse:
        ret += (inverse,)

    if return_counts:
        ret += (counts,)

    # Single value is returned bare, several as a tuple.
    return ret[0] if len(ret) == 1 else ret
class MissingValues(NamedTuple):
    """Immutable record of which missing-value markers were observed."""

    # Whether a float nan marker was seen.
    nan: bool
    # Whether a None marker was seen.
    none: bool

    def to_list(self):
        """Return the observed markers as a list, ``None`` always first."""
        markers = []
        if self.none:
            markers.append(None)
        if self.nan:
            markers.append(np.nan)
        return markers
def _extract_missing(values):
    """Extract missing values from `values`.

    Parameters
    ----------
    values: set
        Set of values to extract missing from.

    Returns
    -------
    output: set
        Set with missing values extracted.

    missing_values: MissingValues
        Object with missing value information.
    """
    missing = {value for value in values if value is None or is_scalar_nan(value)}

    if not missing:
        return values, MissingValues(nan=False, none=False)

    has_none = None in missing
    # Any marker beyond None must be a float nan (that is the only other
    # thing the comprehension above can collect).
    has_nan = len(missing) > (1 if has_none else 0)

    # Return the set stripped of its missing-value markers.
    return values - missing, MissingValues(nan=has_nan, none=has_none)
class _nandict(dict):
    """Dictionary where every nan key maps to one shared value.

    nan != nan, so a plain dict lookup with a *different* nan object would
    miss the stored entry; ``__missing__`` patches that case.
    """

    def __init__(self, mapping):
        super().__init__(mapping)
        # Remember the value associated with the first nan key, if any.
        for key, value in mapping.items():
            if is_scalar_nan(key):
                self.nan_value = value
                break

    def __missing__(self, key):
        if hasattr(self, "nan_value") and is_scalar_nan(key):
            return self.nan_value
        raise KeyError(key)
def _map_to_integer(values, uniques):
    """Encode each element of `values` as its position in `uniques`."""
    xp, _ = get_namespace(values, uniques)
    # nan-aware lookup table: value -> index in `uniques`.
    table = _nandict({value: index for index, value in enumerate(uniques)})
    encoded = [table[value] for value in values]
    return xp.asarray(encoded, device=device(values))
def _unique_python(values, *, return_inverse, return_counts):
    # Object-dtype implementation backing `_unique`; see its docstring.
    try:
        uniques_set, missing_values = _extract_missing(set(values))
        # Sort the regular values, then append the missing markers at the end.
        uniques = sorted(uniques_set) + missing_values.to_list()
        uniques = np.array(uniques, dtype=values.dtype)
    except TypeError:
        # Mixed, unorderable types cannot be sorted/encoded.
        types = sorted(t.__qualname__ for t in set(type(v) for v in values))
        raise TypeError(
            "Encoders require their input argument must be uniformly "
            f"strings or numbers. Got {types}"
        )

    result = (uniques,)
    if return_inverse:
        result += (_map_to_integer(values, uniques),)
    if return_counts:
        result += (_get_counts(values, uniques),)
    return result if len(result) > 1 else result[0]
def _encode(values, *, uniques, check_unknown=True):
    """Helper function to encode values into [0, n_uniques - 1].

    Uses pure python method for object dtype, and numpy method for
    all other dtypes.

    The numpy method has the limitation that the `uniques` need to
    be sorted. Importantly, this is not checked but assumed to already be
    the case. The calling method needs to ensure this for all non-object
    values.

    Parameters
    ----------
    values : ndarray
        Values to encode.

    uniques : ndarray
        The unique values in `values`. If the dtype is not object, then
        `uniques` needs to be sorted.

    check_unknown : bool, default=True
        If True, check for values in `values` that are not in `uniques`
        and raise an error. This is ignored for object dtype, and treated as
        True in this case. This parameter is useful for
        _BaseEncoder._transform() to avoid calling _check_unknown()
        twice.

    Returns
    -------
    encoded : ndarray
        Encoded values
    """
    xp, _ = get_namespace(values, uniques)

    if not xp.isdtype(values.dtype, "numeric"):
        # Object-like dtypes go through the dict-based mapping, which
        # detects unknown values on its own (KeyError).
        try:
            return _map_to_integer(values, uniques)
        except KeyError as e:
            raise ValueError(f"y contains previously unseen labels: {e}")

    if check_unknown:
        diff = _check_unknown(values, uniques)
        if diff:
            raise ValueError(f"y contains previously unseen labels: {diff}")
    # `uniques` is assumed sorted, so positions come from binary search.
    return _searchsorted(uniques, values, xp=xp)
def _check_unknown(values, known_values, return_mask=False):
    """
    Helper function to check for unknowns in values to be encoded.

    Uses pure python method for object dtype, and numpy method for
    all other dtypes.

    Parameters
    ----------
    values : array
        Values to check for unknowns.

    known_values : array
        Known values. Must be unique.

    return_mask : bool, default=False
        If True, return a mask of the same shape as `values` indicating
        the valid values.

    Returns
    -------
    diff : list
        The unique values present in `values` and not in `known_values`.

    valid_mask : boolean array
        Additionally returned if ``return_mask=True``.
    """
    xp, _ = get_namespace(values, known_values)
    valid_mask = None

    if not xp.isdtype(values.dtype, "numeric"):
        # Object-like dtype: operate on python sets, pulling out the missing
        # markers (None / nan) which do not compare reliably inside sets.
        values_set = set(values)
        values_set, missing_in_values = _extract_missing(values_set)

        uniques_set = set(known_values)
        uniques_set, missing_in_uniques = _extract_missing(uniques_set)
        diff = values_set - uniques_set

        # A missing marker is only "unknown" when absent from known_values.
        nan_in_diff = missing_in_values.nan and not missing_in_uniques.nan
        none_in_diff = missing_in_values.none and not missing_in_uniques.none

        def is_valid(value):
            return (
                value in uniques_set
                or (missing_in_uniques.none and value is None)
                or (missing_in_uniques.nan and is_scalar_nan(value))
            )

        if return_mask:
            if diff or nan_in_diff or none_in_diff:
                valid_mask = xp.array([is_valid(value) for value in values])
            else:
                # Nothing unknown: every position is valid.
                valid_mask = xp.ones(len(values), dtype=xp.bool)

        diff = list(diff)
        if none_in_diff:
            diff.append(None)
        if nan_in_diff:
            diff.append(np.nan)
    else:
        unique_values = xp.unique_values(values)
        diff = xpx.setdiff1d(unique_values, known_values, assume_unique=True, xp=xp)
        if return_mask:
            if diff.size:
                valid_mask = _isin(values, known_values, xp)
            else:
                valid_mask = xp.ones(len(values), dtype=xp.bool)

        # check for nans in the known_values
        if xp.any(xp.isnan(known_values)):
            diff_is_nan = xp.isnan(diff)
            if xp.any(diff_is_nan):
                # nan is a known value here, so nan positions in `values`
                # are marked valid in the mask (setdiff1d/_isin miss them
                # because nan != nan).
                if diff.size and return_mask:
                    is_nan = xp.isnan(values)
                    valid_mask[is_nan] = 1

                # remove nan from diff
                diff = diff[~diff_is_nan]

        diff = list(diff)

    if return_mask:
        return diff, valid_mask
    return diff
class _NaNCounter(Counter):
    """Counter that aggregates every nan item under a single count.

    nan != nan, so a plain Counter would give each nan its own entry; the
    aggregate count is kept in ``self.nan_count`` instead and served through
    ``__missing__``.
    """

    def __init__(self, items):
        super().__init__(self._generate_items(items))

    def _generate_items(self, items):
        """Yield the non-nan items, tallying nans in ``self.nan_count``."""
        for item in items:
            if is_scalar_nan(item):
                if not hasattr(self, "nan_count"):
                    self.nan_count = 0
                self.nan_count += 1
            else:
                yield item

    def __missing__(self, key):
        if hasattr(self, "nan_count") and is_scalar_nan(key):
            return self.nan_count
        raise KeyError(key)
def _get_counts(values, uniques):
    """Get the count of each of the `uniques` in `values`.

    The counts will use the order passed in by `uniques`. For non-object dtypes,
    `uniques` is assumed to be sorted and `np.nan` is at the end.
    """
    if values.dtype.kind in "OU":
        # Object/str dtype: count with a nan-aware Counter; entries of
        # `uniques` absent from `values` keep a zero count (KeyError
        # suppressed).
        counter = _NaNCounter(values)
        output = np.zeros(len(uniques), dtype=np.int64)
        for i, item in enumerate(uniques):
            with suppress(KeyError):
                output[i] = counter[item]
        return output

    unique_values, counts = _unique_np(values, return_counts=True)

    # Reorder unique_values based on input: `uniques`
    uniques_in_values = np.isin(uniques, unique_values, assume_unique=True)
    if np.isnan(unique_values[-1]) and np.isnan(uniques[-1]):
        # np.isin never matches nan against nan; patch the entry manually.
        uniques_in_values[-1] = True

    unique_valid_indices = np.searchsorted(unique_values, uniques[uniques_in_values])
    # Scatter the counts into `uniques` order; missing entries stay zero.
    output = np.zeros_like(uniques, dtype=np.int64)
    output[uniques_in_values] = counts[unique_valid_indices]
    return output

View File

@@ -0,0 +1,34 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from ._repr_html.base import _HTMLDocumentationLinkMixin
from ._repr_html.estimator import (
_get_visual_block,
_IDCounter,
_VisualBlock,
_write_estimator_html,
_write_label_html,
estimator_html_repr,
)
# Names re-exported from sklearn.utils._repr_html for backward compatibility
# with this deprecated module path.
__all__ = [
    "_HTMLDocumentationLinkMixin",
    "_IDCounter",
    "_VisualBlock",
    "_get_visual_block",
    "_write_estimator_html",
    "_write_label_html",
    "estimator_html_repr",
]

# TODO(1.8): Remove the entire module
# Warn at import time so callers migrate to sklearn.utils._repr_html.
warnings.warn(
    "Importing from sklearn.utils._estimator_html_repr is deprecated. The tools have "
    "been moved to sklearn.utils._repr_html. Be aware that this module is private and "
    "may be subject to change in the future. The module _estimator_html_repr will be "
    "removed in 1.8.0.",
    FutureWarning,
    stacklevel=2,
)

View File

@@ -0,0 +1,19 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
"""
Uses C++ map containers for fast dict-like behavior with keys being
integers, and values float.
"""
from libcpp.map cimport map as cpp_map
from ._typedefs cimport float64_t, intp_t
###############################################################################
# An object to be used in Python

cdef class IntFloatDict:
    # Backing C++ std::map holding the int -> float entries.
    cdef cpp_map[intp_t, float64_t] my_map
    # Fill the preallocated `keys`/`values` arrays from the map.
    cdef _to_arrays(self, intp_t [:] keys, float64_t [:] values)

View File

@@ -0,0 +1,137 @@
"""
Uses C++ map containers for fast dict-like behavior with keys being
integers, and values float.
"""
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
# C++
from cython.operator cimport dereference as deref, preincrement as inc
from libcpp.utility cimport pair
from libcpp.map cimport map as cpp_map
import numpy as np
from ._typedefs cimport float64_t, intp_t
###############################################################################
# An object to be used in Python
# Lookup is faster than dict (up to 10 times), and so is full traversal
# (up to 50 times), and assignment (up to 6 times), but creation is
# slower (up to 3 times). Also, a large benefit is that memory
# consumption is reduced a lot compared to a Python dict
cdef class IntFloatDict:
    """Dict-like mapping from intp keys to float64 values, backed by a C++ map.

    Iteration (``__iter__``, ``to_arrays``) follows the C++ map order, i.e.
    ascending key order.
    """

    def __init__(
        self,
        intp_t[:] keys,
        float64_t[:] values,
    ):
        # Populate the map from two parallel arrays; later duplicates of a
        # key overwrite earlier ones (operator[] assignment).
        cdef int i
        cdef int size = values.size
        # Should check that sizes for keys and values are equal, and
        # after should boundcheck(False)
        for i in range(size):
            self.my_map[keys[i]] = values[i]

    def __len__(self):
        # Number of stored (key, value) pairs.
        return self.my_map.size()

    def __getitem__(self, int key):
        """Return the value stored for `key`; raise KeyError if absent."""
        cdef cpp_map[intp_t, float64_t].iterator it = self.my_map.find(key)
        if it == self.my_map.end():
            # The key is not in the dict
            raise KeyError('%i' % key)
        return deref(it).second

    def __setitem__(self, int key, float value):
        # operator[] inserts the key if missing, then assigns.
        self.my_map[key] = value

    # Cython 0.20 generates buggy code below. Commenting this out for now
    # and relying on the to_arrays method
    # def __iter__(self):
    #     cdef cpp_map[intp_t, float64_t].iterator it = self.my_map.begin()
    #     cdef cpp_map[intp_t, float64_t].iterator end = self.my_map.end()
    #     while it != end:
    #         yield deref(it).first, deref(it).second
    #         inc(it)

    def __iter__(self):
        """Yield (key, value) pairs, in map (ascending key) order."""
        # Materialize the map into arrays first: direct iterator-based
        # yielding miscompiled under old Cython (see note above).
        cdef int size = self.my_map.size()
        cdef intp_t [:] keys = np.empty(size, dtype=np.intp)
        cdef float64_t [:] values = np.empty(size, dtype=np.float64)
        self._to_arrays(keys, values)
        cdef int idx
        cdef intp_t key
        cdef float64_t value
        for idx in range(size):
            key = keys[idx]
            value = values[idx]
            yield key, value

    def to_arrays(self):
        """Return the key, value representation of the IntFloatDict
        object.

        Returns
        =======
        keys : ndarray, shape (n_items, ), dtype=int
            The indices of the data points
        values : ndarray, shape (n_items, ), dtype=float
            The values of the data points
        """
        cdef int size = self.my_map.size()
        keys = np.empty(size, dtype=np.intp)
        values = np.empty(size, dtype=np.float64)
        self._to_arrays(keys, values)
        return keys, values

    cdef _to_arrays(self, intp_t [:] keys, float64_t [:] values):
        # Internal version of to_arrays that takes already-initialized arrays
        cdef cpp_map[intp_t, float64_t].iterator it = self.my_map.begin()
        cdef cpp_map[intp_t, float64_t].iterator end = self.my_map.end()
        cdef int index = 0
        while it != end:
            keys[index] = deref(it).first
            values[index] = deref(it).second
            inc(it)
            index += 1

    def update(self, IntFloatDict other):
        """Merge `other` into self, overwriting values for existing keys."""
        cdef cpp_map[intp_t, float64_t].iterator it = other.my_map.begin()
        cdef cpp_map[intp_t, float64_t].iterator end = other.my_map.end()
        while it != end:
            self.my_map[deref(it).first] = deref(it).second
            inc(it)

    def copy(self):
        """Return an independent copy of this IntFloatDict."""
        cdef IntFloatDict out_obj = IntFloatDict.__new__(IntFloatDict)
        # The '=' operator is a copy operator for C++ maps
        out_obj.my_map = self.my_map
        return out_obj

    def append(self, intp_t key, float64_t value):
        """Insert (key, value); std::map.insert keeps the existing value if
        `key` is already present (unlike ``__setitem__``)."""
        # Construct our arguments
        cdef pair[intp_t, float64_t] args
        args.first = key
        args.second = value
        self.my_map.insert(args)
###############################################################################
# operation on dict
def argmin(IntFloatDict d):
    """Return the (key, value) pair of `d` with the smallest value.

    Returns (-1, np.inf) when the dict is empty.
    """
    cdef cpp_map[intp_t, float64_t].iterator it = d.my_map.begin()
    cdef cpp_map[intp_t, float64_t].iterator end = d.my_map.end()
    cdef intp_t min_key = -1
    cdef float64_t min_value = np.inf
    # Linear scan over the map, tracking the best pair seen so far. A strict
    # '<' keeps the first (smallest-key) entry in case of ties.
    while it != end:
        if deref(it).second < min_value:
            min_value = deref(it).second
            min_key = deref(it).first
        inc(it)
    return min_key, min_value

View File

@@ -0,0 +1,14 @@
# Heap routines, used in various Cython implementations.

from cython cimport floating

from ._typedefs cimport intp_t

# Push (val, val_idx) onto the fixed-size max-heap stored in the parallel
# `values`/`indices` buffers of length `size`; see the implementation for
# the full contract.
cdef int heap_push(
    floating* values,
    intp_t* indices,
    intp_t size,
    floating val,
    intp_t val_idx,
) noexcept nogil

View File

@@ -0,0 +1,85 @@
from cython cimport floating
from ._typedefs cimport intp_t
cdef inline int heap_push(
    floating* values,
    intp_t* indices,
    intp_t size,
    floating val,
    intp_t val_idx,
) noexcept nogil:
    """Push a tuple (val, val_idx) onto a fixed-size max-heap.

    The max-heap is represented as a Structure of Arrays where:
     - values is the array containing the data to construct the heap with
     - indices is the array containing the indices (meta-data) of each value

    Notes
    -----
    Arrays are manipulated via a pointer to their first element and their size
    as to ease the processing of dynamically allocated buffers.

    For instance, in pseudo-code:

        values = [1.2, 0.4, 0.1],
        indices = [42, 1, 5],
        heap_push(
            values=values,
            indices=indices,
            size=3,
            val=0.2,
            val_idx=4,
        )

    will modify values and indices inplace, giving at the end of the call:

        values  == [0.4, 0.2, 0.1]
        indices == [1, 4, 5]
    """
    cdef:
        intp_t current_idx, left_child_idx, right_child_idx, swap_idx

    # Check if val should be in heap
    if val >= values[0]:
        # val is not smaller than the current maximum (the root): discard it.
        return 0

    # Insert val at position zero
    values[0] = val
    indices[0] = val_idx

    # Descend the heap, swapping values until the max heap criterion is met
    current_idx = 0
    while True:
        left_child_idx = 2 * current_idx + 1
        right_child_idx = left_child_idx + 1

        if left_child_idx >= size:
            # No children: current position is a leaf, we are done sifting.
            break
        elif right_child_idx >= size:
            # Only a left child exists.
            if values[left_child_idx] > val:
                swap_idx = left_child_idx
            else:
                break
        elif values[left_child_idx] >= values[right_child_idx]:
            # Compare val against the larger of the two children.
            if val < values[left_child_idx]:
                swap_idx = left_child_idx
            else:
                break
        else:
            if val < values[right_child_idx]:
                swap_idx = right_child_idx
            else:
                break

        # Pull the larger child up and continue sifting down from its slot.
        values[current_idx] = values[swap_idx]
        indices[current_idx] = indices[swap_idx]

        current_idx = swap_idx

    # Final resting place of the pushed element.
    values[current_idx] = val
    indices[current_idx] = val_idx

    return 0

View File

@@ -0,0 +1,755 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import numbers
import sys
import warnings
from collections import UserList
from itertools import compress, islice
import numpy as np
from scipy.sparse import issparse
from sklearn.utils.fixes import PYARROW_VERSION_BELOW_17
from ._array_api import _is_numpy_namespace, get_namespace
from ._param_validation import Interval, validate_params
from .extmath import _approximate_mode
from .validation import (
_check_sample_weight,
_is_arraylike_not_scalar,
_is_pandas_df,
_is_polars_df_or_series,
_is_pyarrow_data,
_use_interchange_protocol,
check_array,
check_consistent_length,
check_random_state,
)
def _array_indexing(array, key, key_dtype, axis):
    """Index a NumPy/Array API array or scipy sparse matrix along ``axis``."""
    namespace, uses_array_api = get_namespace(array)
    if uses_array_api:
        # Array API inputs are indexed through the standard ``take``.
        return namespace.take(array, key, axis=axis)
    if key_dtype == "bool" and issparse(array):
        # Sparse matrices need a materialized boolean mask.
        key = np.asarray(key)
    if isinstance(key, tuple):
        # A tuple key would trigger multi-dimensional indexing; use a list.
        key = list(key)
    if axis == 0:
        return array[key, ...]
    return array[:, key]
def _pandas_indexing(X, key, key_dtype, axis):
    """Index a pandas DataFrame or Series along ``axis``."""
    if _is_arraylike_not_scalar(key):
        key = np.asarray(key)

    scalar_or_slice = isinstance(key, slice) or np.isscalar(key)
    if key_dtype == "int" and not scalar_or_slice:
        # using take() instead of iloc[] ensures the return value is a "proper"
        # copy that will not raise SettingWithCopyWarning
        return X.take(key, axis=axis)

    # Positional indexing for integers, label-based indexing otherwise.
    indexer = X.loc if key_dtype != "int" else X.iloc
    if axis:
        return indexer[:, key]
    return indexer[key]
def _list_indexing(X, key, key_dtype):
"""Index a Python list."""
if np.isscalar(key) or isinstance(key, slice):
# key is a slice or a scalar
return X[key]
if key_dtype == "bool":
# key is a boolean array-like
return list(compress(X, key))
# key is a integer array-like of key
return [X[idx] for idx in key]
def _polars_indexing(X, key, key_dtype, axis):
    """Index a polars DataFrame or Series."""
    # Polars behaves more like Python lists: normalize the key accordingly.
    if isinstance(key, np.ndarray):
        # Convert every element of the array to a plain Python scalar.
        key = key.tolist()
    elif not (np.isscalar(key) or isinstance(key, slice)):
        key = list(key)

    if axis == 1:
        # X is necessarily a polars DataFrame here; column selection accepts
        # integer and string scalars, and lists of integers, strings, booleans.
        return X[:, key]

    # axis == 0 from here on.
    if key_dtype == "bool":
        # Series and DataFrame share the same boolean-mask API.
        return X.filter(key)

    # Integer scalars and lists of integers index Series and DataFrame alike.
    result = X[key]
    if np.isscalar(key) and len(X.shape) == 2:
        # A scalar row selection on a DataFrame yields a one-row frame;
        # return a Series instead for consistency with pandas.
        pl = sys.modules["polars"]
        return pl.Series(result.row(0))
    return result
def _pyarrow_indexing(X, key, key_dtype, axis):
    """Index a pyarrow data.

    `X` may be a Table/RecordBatch (2D) or an Array/ChunkedArray (1D).
    """
    scalar_key = np.isscalar(key)
    if isinstance(key, slice):
        # Normalize the slice into an explicit list of integer positions.
        if isinstance(key.stop, str):
            # Slices of column names are endpoint-inclusive, hence the +1.
            start = X.column_names.index(key.start)
            stop = X.column_names.index(key.stop) + 1
        else:
            start = 0 if not key.start else key.start
            stop = key.stop
        step = 1 if not key.step else key.step
        key = list(range(start, stop, step))

    if axis == 1:
        # Here we are certain that X is a pyarrow Table or RecordBatch.
        if key_dtype == "int" and not isinstance(key, list):
            # pyarrow's X.select behavior is more consistent with integer lists.
            key = np.asarray(key).tolist()
        if key_dtype == "bool":
            # Column selection takes positions, not masks: convert the mask to
            # the positions of its True entries.
            key = np.asarray(key).nonzero()[0].tolist()

        if scalar_key:
            return X.column(key)
        return X.select(key)

    # axis == 0 from here on
    if scalar_key:
        if hasattr(X, "shape"):
            # X is a Table or RecordBatch
            key = [key]
        else:
            # X is an Array/ChunkedArray: return the element as a Python scalar.
            return X[key].as_py()
    elif not isinstance(key, list):
        key = np.asarray(key)

    if key_dtype == "bool":
        # TODO(pyarrow): remove version checking and following if-branch when
        # pyarrow==17.0.0 is the minimal version, see pyarrow issue
        # https://github.com/apache/arrow/issues/42013 for more info
        if PYARROW_VERSION_BELOW_17:
            import pyarrow

            if not isinstance(key, pyarrow.BooleanArray):
                key = pyarrow.array(key, type=pyarrow.bool_())
        X_indexed = X.filter(key)
    else:
        X_indexed = X.take(key)

    if scalar_key and len(getattr(X, "shape", [0])) == 2:
        # X_indexed is a dataframe-like with a single row; we return a Series to be
        # consistent with pandas
        pa = sys.modules["pyarrow"]
        return pa.array(X_indexed.to_pylist()[0].values())
    return X_indexed
def _determine_key_type(key, accept_slice=True):
"""Determine the data type of key.
Parameters
----------
key : scalar, slice or array-like
The key from which we want to infer the data type.
accept_slice : bool, default=True
Whether or not to raise an error if the key is a slice.
Returns
-------
dtype : {'int', 'str', 'bool', None}
Returns the data type of key.
"""
err_msg = (
"No valid specification of the columns. Only a scalar, list or "
"slice of all integers or all strings, or boolean mask is "
"allowed"
)
dtype_to_str = {int: "int", str: "str", bool: "bool", np.bool_: "bool"}
array_dtype_to_str = {
"i": "int",
"u": "int",
"b": "bool",
"O": "str",
"U": "str",
"S": "str",
}
if key is None:
return None
if isinstance(key, tuple(dtype_to_str.keys())):
try:
return dtype_to_str[type(key)]
except KeyError:
raise ValueError(err_msg)
if isinstance(key, slice):
if not accept_slice:
raise TypeError(
"Only array-like or scalar are supported. A Python slice was given."
)
if key.start is None and key.stop is None:
return None
key_start_type = _determine_key_type(key.start)
key_stop_type = _determine_key_type(key.stop)
if key_start_type is not None and key_stop_type is not None:
if key_start_type != key_stop_type:
raise ValueError(err_msg)
if key_start_type is not None:
return key_start_type
return key_stop_type
# TODO(1.9) remove UserList when the force_int_remainder_cols param
# of ColumnTransformer is removed
if isinstance(key, (list, tuple, UserList)):
unique_key = set(key)
key_type = {_determine_key_type(elt) for elt in unique_key}
if not key_type:
return None
if len(key_type) != 1:
raise ValueError(err_msg)
return key_type.pop()
if hasattr(key, "dtype"):
xp, is_array_api = get_namespace(key)
# NumPy arrays are special-cased in their own branch because the Array API
# cannot handle object/string-based dtypes that are often used to index
# columns of dataframes by names.
if is_array_api and not _is_numpy_namespace(xp):
if xp.isdtype(key.dtype, "bool"):
return "bool"
elif xp.isdtype(key.dtype, "integral"):
return "int"
else:
raise ValueError(err_msg)
else:
try:
return array_dtype_to_str[key.dtype.kind]
except KeyError:
raise ValueError(err_msg)
raise ValueError(err_msg)
def _safe_indexing(X, indices, *, axis=0):
    """Return rows, items or columns of X using indices.

    .. warning::

        This utility is documented, but **private**. This means that
        backward compatibility might be broken without any deprecation
        cycle.

    Parameters
    ----------
    X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series
        Data from which to sample rows, items or columns. `list` are only
        supported when `axis=0`.
    indices : bool, int, str, slice, array-like
        - If `axis=0`, boolean and integer array-like, integer slice,
          and scalar integer are supported.
        - If `axis=1`:
            - to select a single column, `indices` can be of `int` type for
              all `X` types and `str` only for dataframe. The selected subset
              will be 1D, unless `X` is a sparse matrix in which case it will
              be 2D.
            - to select multiples columns, `indices` can be one of the
              following: `list`, `array`, `slice`. The type used in
              these containers can be one of the following: `int`, 'bool' and
              `str`. However, `str` is only supported when `X` is a dataframe.
              The selected subset will be 2D.
    axis : int, default=0
        The axis along which `X` will be subsampled. `axis=0` will select
        rows while `axis=1` will select columns.

    Returns
    -------
    subset
        Subset of X on axis 0 or 1.

    Notes
    -----
    CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are
    not supported.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils import _safe_indexing
    >>> data = np.array([[1, 2], [3, 4], [5, 6]])
    >>> _safe_indexing(data, 0, axis=0)  # select the first row
    array([1, 2])
    >>> _safe_indexing(data, 0, axis=1)  # select the first column
    array([1, 3, 5])
    """
    if indices is None:
        return X

    if axis not in (0, 1):
        raise ValueError(
            "'axis' should be either 0 (to index rows) or 1 (to index "
            " column). Got {} instead.".format(axis)
        )

    indices_dtype = _determine_key_type(indices)

    # Reject combinations of axis/container/key type that have no meaning.
    if axis == 0 and indices_dtype == "str":
        raise ValueError("String indexing is not supported with 'axis=0'")

    if axis == 1 and isinstance(X, list):
        raise ValueError("axis=1 is not supported for lists")

    if axis == 1 and (ndim := len(getattr(X, "shape", [0]))) != 2:
        raise ValueError(
            "'X' should be a 2D NumPy array, 2D sparse matrix or "
            "dataframe when indexing the columns (i.e. 'axis=1'). "
            f"Got {type(X)} instead with {ndim} dimension(s)."
        )

    if (
        axis == 1
        and indices_dtype == "str"
        and not (_is_pandas_df(X) or _use_interchange_protocol(X))
    ):
        raise ValueError(
            "Specifying the columns using strings is only supported for dataframes."
        )

    # Dispatch on the container type.
    if hasattr(X, "iloc"):
        # TODO: we should probably use _is_pandas_df_or_series(X) instead but:
        # 1) Currently, it (probably) works for dataframes compliant to pandas' API.
        # 2) Updating would require updating some tests such as
        # test_train_test_split_mock_pandas.
        return _pandas_indexing(X, indices, indices_dtype, axis=axis)
    elif _is_polars_df_or_series(X):
        return _polars_indexing(X, indices, indices_dtype, axis=axis)
    elif _is_pyarrow_data(X):
        return _pyarrow_indexing(X, indices, indices_dtype, axis=axis)
    elif _use_interchange_protocol(X):  # pragma: no cover
        # Once the dataframe X is converted into its dataframe interchange protocol
        # version by calling X.__dataframe__(), it becomes very hard to turn it back
        # into its original type, e.g., a pyarrow.Table, see
        # https://github.com/data-apis/dataframe-api/issues/85.
        # Bug fix: the previous code did `raise warnings.warn(...)`; since
        # `warnings.warn` returns None, that crashed with "TypeError: exceptions
        # must derive from BaseException" right after warning. We only warn and
        # fall through to the generic array/list indexing below, as the message
        # promises. The missing space after "protocol" was also fixed.
        warnings.warn(
            message="A data object with support for the dataframe interchange "
            "protocol was passed, but scikit-learn does currently not know how to "
            "handle this kind of data. Some array/list indexing will be tried.",
            category=UserWarning,
        )

    if hasattr(X, "shape"):
        return _array_indexing(X, indices, indices_dtype, axis=axis)
    else:
        return _list_indexing(X, indices, indices_dtype)
def _safe_assign(X, values, *, row_indexer=None, column_indexer=None):
"""Safe assignment to a numpy array, sparse matrix, or pandas dataframe.
Parameters
----------
X : {ndarray, sparse-matrix, dataframe}
Array to be modified. It is expected to be 2-dimensional.
values : ndarray
The values to be assigned to `X`.
row_indexer : array-like, dtype={int, bool}, default=None
A 1-dimensional array to select the rows of interest. If `None`, all
rows are selected.
column_indexer : array-like, dtype={int, bool}, default=None
A 1-dimensional array to select the columns of interest. If `None`, all
columns are selected.
"""
row_indexer = slice(None, None, None) if row_indexer is None else row_indexer
column_indexer = (
slice(None, None, None) if column_indexer is None else column_indexer
)
if hasattr(X, "iloc"): # pandas dataframe
with warnings.catch_warnings():
# pandas >= 1.5 raises a warning when using iloc to set values in a column
# that does not have the same type as the column being set. It happens
# for instance when setting a categorical column with a string.
# In the future the behavior won't change and the warning should disappear.
# TODO(1.3): check if the warning is still raised or remove the filter.
warnings.simplefilter("ignore", FutureWarning)
X.iloc[row_indexer, column_indexer] = values
else: # numpy array or sparse matrix
X[row_indexer, column_indexer] = values
def _get_column_indices_for_bool_or_int(key, n_columns):
    """Convert a boolean mask or integer key into a list of column positions."""
    # Index an arange with the key: numpy resolves negative indices and
    # boolean masks for us, and flags out-of-range values with IndexError.
    try:
        positions = _safe_indexing(np.arange(n_columns), key)
    except IndexError as e:
        raise ValueError(
            f"all features must be in [0, {n_columns - 1}] or [-{n_columns}, 0]"
        ) from e
    return np.atleast_1d(positions).tolist()
def _get_column_indices(X, key):
    """Get feature column indices for input data X and key.

    For accepted values of `key`, see the docstring of
    :func:`_safe_indexing`.
    """
    key_dtype = _determine_key_type(key)
    if _use_interchange_protocol(X):
        # Dataframe interchange protocol objects have their own code path.
        return _get_column_indices_interchange(X.__dataframe__(), key, key_dtype)

    n_columns = X.shape[1]
    if isinstance(key, (list, tuple)) and not key:
        # we get an empty list
        return []
    elif key_dtype in ("bool", "int"):
        return _get_column_indices_for_bool_or_int(key, n_columns)
    else:
        # From here on, `key` selects columns by name, which requires a
        # dataframe exposing a `columns` index.
        try:
            all_columns = X.columns
        except AttributeError:
            raise ValueError(
                "Specifying the columns using strings is only supported for dataframes."
            )
        if isinstance(key, str):
            columns = [key]
        elif isinstance(key, slice):
            start, stop = key.start, key.stop
            if start is not None:
                start = all_columns.get_loc(start)
            if stop is not None:
                # pandas indexing with strings is endpoint included
                stop = all_columns.get_loc(stop) + 1
            else:
                stop = n_columns + 1
            return list(islice(range(n_columns), start, stop))
        else:
            columns = list(key)
        try:
            column_indices = []
            for col in columns:
                col_idx = all_columns.get_loc(col)
                if not isinstance(col_idx, numbers.Integral):
                    # get_loc returns a slice or mask (not an int) when the
                    # label occurs more than once in the column index.
                    raise ValueError(
                        f"Selected columns, {columns}, are not unique in dataframe"
                    )
                column_indices.append(col_idx)
        except KeyError as e:
            raise ValueError("A given column is not a column of the dataframe") from e
        return column_indices
def _get_column_indices_interchange(X_interchange, key, key_dtype):
"""Same as _get_column_indices but for X with __dataframe__ protocol."""
n_columns = X_interchange.num_columns()
if isinstance(key, (list, tuple)) and not key:
# we get an empty list
return []
elif key_dtype in ("bool", "int"):
return _get_column_indices_for_bool_or_int(key, n_columns)
else:
column_names = list(X_interchange.column_names())
if isinstance(key, slice):
if key.step not in [1, None]:
raise NotImplementedError("key.step must be 1 or None")
start, stop = key.start, key.stop
if start is not None:
start = column_names.index(start)
if stop is not None:
stop = column_names.index(stop) + 1
else:
stop = n_columns + 1
return list(islice(range(n_columns), start, stop))
selected_columns = [key] if np.isscalar(key) else key
try:
return [column_names.index(col) for col in selected_columns]
except ValueError as e:
raise ValueError("A given column is not a column of the dataframe") from e
@validate_params(
    {
        "replace": ["boolean"],
        "n_samples": [Interval(numbers.Integral, 1, None, closed="left"), None],
        "random_state": ["random_state"],
        "stratify": ["array-like", "sparse matrix", None],
        "sample_weight": ["array-like", None],
    },
    prefer_skip_nested_validation=True,
)
def resample(
    *arrays,
    replace=True,
    n_samples=None,
    random_state=None,
    stratify=None,
    sample_weight=None,
):
    """Resample arrays or sparse matrices in a consistent way.

    The default strategy implements one step of the bootstrapping
    procedure.

    Parameters
    ----------
    *arrays : sequence of array-like of shape (n_samples,) or \
            (n_samples, n_outputs)
        Indexable data-structures can be arrays, lists, dataframes or scipy
        sparse matrices with consistent first dimension.

    replace : bool, default=True
        Implements resampling with replacement. It must be set to True
        whenever sampling with non-uniform weights: a few data points with very large
        weights are expected to be sampled several times with probability to preserve
        the distribution induced by the weights. If False, this will implement
        (sliced) random permutations.

    n_samples : int, default=None
        Number of samples to generate. If left to None this is
        automatically set to the first dimension of the arrays.
        If replace is False it should not be larger than the length of
        arrays.

    random_state : int, RandomState instance or None, default=None
        Determines random number generation for shuffling
        the data.
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    stratify : {array-like, sparse matrix} of shape (n_samples,) or \
            (n_samples, n_outputs), default=None
        If not None, data is split in a stratified fashion, using this as
        the class labels.

    sample_weight : array-like of shape (n_samples,), default=None
        Contains weight values to be associated with each sample. Values are
        normalized to sum to one and interpreted as probability for sampling
        each data point.

        .. versionadded:: 1.7

    Returns
    -------
    resampled_arrays : sequence of array-like of shape (n_samples,) or \
            (n_samples, n_outputs)
        Sequence of resampled copies of the collections. The original arrays
        are not impacted.

    See Also
    --------
    shuffle : Shuffle arrays or sparse matrices in a consistent way.

    Examples
    --------
    It is possible to mix sparse and dense arrays in the same run::

      >>> import numpy as np
      >>> X = np.array([[1., 0.], [2., 1.], [0., 0.]])
      >>> y = np.array([0, 1, 2])

      >>> from scipy.sparse import coo_matrix
      >>> X_sparse = coo_matrix(X)

      >>> from sklearn.utils import resample
      >>> X, X_sparse, y = resample(X, X_sparse, y, random_state=0)
      >>> X
      array([[1., 0.],
             [2., 1.],
             [1., 0.]])

      >>> X_sparse
      <Compressed Sparse Row sparse matrix of dtype 'float64'
          with 4 stored elements and shape (3, 2)>

      >>> X_sparse.toarray()
      array([[1., 0.],
             [2., 1.],
             [1., 0.]])

      >>> y
      array([0, 1, 0])

      >>> resample(y, n_samples=2, random_state=0)
      array([0, 1])

    Example using stratification::

      >>> y = [0, 0, 1, 1, 1, 1, 1, 1, 1]
      >>> resample(y, n_samples=5, replace=False, stratify=y,
      ...          random_state=0)
      [1, 1, 1, 0, 1]
    """
    max_n_samples = n_samples
    random_state = check_random_state(random_state)

    if len(arrays) == 0:
        # Nothing to resample.
        return None

    first = arrays[0]
    n_samples = first.shape[0] if hasattr(first, "shape") else len(first)

    if max_n_samples is None:
        max_n_samples = n_samples
    elif (max_n_samples > n_samples) and (not replace):
        raise ValueError(
            "Cannot sample %d out of arrays with dim %d when replace is False"
            % (max_n_samples, n_samples)
        )

    check_consistent_length(*arrays)

    # Weighted sampling is only defined for the with-replacement,
    # non-stratified case.
    if sample_weight is not None and not replace:
        raise NotImplementedError(
            "Resampling with sample_weight is only implemented for replace=True."
        )
    if sample_weight is not None and stratify is not None:
        raise NotImplementedError(
            "Resampling with sample_weight is only implemented for stratify=None."
        )

    if stratify is None:
        if replace:
            if sample_weight is not None:
                sample_weight = _check_sample_weight(
                    sample_weight, first, dtype=np.float64
                )
                # Normalize the weights into sampling probabilities.
                p = sample_weight / sample_weight.sum()
            else:
                p = None
            indices = random_state.choice(
                n_samples,
                size=max_n_samples,
                p=p,
                replace=True,
            )
        else:
            # Without replacement: a sliced random permutation.
            indices = np.arange(n_samples)
            random_state.shuffle(indices)
            indices = indices[:max_n_samples]
    else:
        # Code adapted from StratifiedShuffleSplit()
        y = check_array(stratify, ensure_2d=False, dtype=None)
        if y.ndim == 2:
            # for multi-label y, map each distinct row to a string repr
            # using join because str(row) uses an ellipsis if len(row) > 1000
            y = np.array([" ".join(row.astype("str")) for row in y])

        classes, y_indices = np.unique(y, return_inverse=True)
        n_classes = classes.shape[0]

        class_counts = np.bincount(y_indices)

        # Find the sorted list of instances for each class:
        # (np.unique above performs a sort, so code is O(n logn) already)
        class_indices = np.split(
            np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1]
        )

        # Per-class sample counts proportional to the class frequencies.
        n_i = _approximate_mode(class_counts, max_n_samples, random_state)

        indices = []

        for i in range(n_classes):
            indices_i = random_state.choice(class_indices[i], n_i[i], replace=replace)
            indices.extend(indices_i)

        indices = random_state.permutation(indices)

    # convert sparse matrices to CSR for row-based indexing
    arrays = [a.tocsr() if issparse(a) else a for a in arrays]
    resampled_arrays = [_safe_indexing(a, indices) for a in arrays]
    if len(resampled_arrays) == 1:
        # syntactic sugar for the unit argument case
        return resampled_arrays[0]
    else:
        return resampled_arrays
def shuffle(*arrays, random_state=None, n_samples=None):
    """Shuffle arrays or sparse matrices in a consistent way.

    This is a convenience alias to ``resample(*arrays, replace=False)`` to do
    random permutations of the collections.

    Parameters
    ----------
    *arrays : sequence of indexable data-structures
        Indexable data-structures can be arrays, lists, dataframes or scipy
        sparse matrices with consistent first dimension.

    random_state : int, RandomState instance or None, default=None
        Determines random number generation for shuffling
        the data.
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    n_samples : int, default=None
        Number of samples to generate. If left to None this is
        automatically set to the first dimension of the arrays. It should
        not be larger than the length of arrays.

    Returns
    -------
    shuffled_arrays : sequence of indexable data-structures
        Sequence of shuffled copies of the collections. The original arrays
        are not impacted.

    See Also
    --------
    resample : Resample arrays or sparse matrices in a consistent way.

    Examples
    --------
    It is possible to mix sparse and dense arrays in the same run::

      >>> import numpy as np
      >>> X = np.array([[1., 0.], [2., 1.], [0., 0.]])
      >>> y = np.array([0, 1, 2])

      >>> from scipy.sparse import coo_matrix
      >>> X_sparse = coo_matrix(X)

      >>> from sklearn.utils import shuffle
      >>> X, X_sparse, y = shuffle(X, X_sparse, y, random_state=0)
      >>> X
      array([[0., 0.],
             [2., 1.],
             [1., 0.]])

      >>> X_sparse
      <Compressed Sparse Row sparse matrix of dtype 'float64'
          with 3 stored elements and shape (3, 2)>

      >>> X_sparse.toarray()
      array([[0., 0.],
             [2., 1.],
             [1., 0.]])

      >>> y
      array([2, 1, 0])

      >>> shuffle(y, n_samples=2, random_state=0)
      array([0, 1])
    """
    # Shuffling is sampling without replacement over the (possibly sliced)
    # full length, which is exactly what resample(replace=False) implements.
    return resample(
        *arrays, replace=False, random_state=random_state, n_samples=n_samples
    )

View File

@@ -0,0 +1,51 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from libc.math cimport isnan, isinf
from cython cimport floating
# Tri-state outcome of a finiteness scan over an array.
cpdef enum FiniteStatus:
    all_finite = 0
    has_nan = 1
    has_infinite = 2


def cy_isfinite(floating[::1] a, bint allow_nan=False):
    """Return the FiniteStatus of the contiguous 1D array `a`.

    When `allow_nan` is True, NaN values are tolerated and only infinities
    are reported.
    """
    cdef FiniteStatus result
    # The scan is pure C arithmetic: release the GIL while it runs.
    with nogil:
        result = _isfinite(a, allow_nan)
    return result


cdef inline FiniteStatus _isfinite(floating[::1] a, bint allow_nan) noexcept nogil:
    # NOTE(review): takes the address of a[0] -- this assumes `a` is
    # non-empty; confirm callers never pass a zero-length view.
    cdef floating* a_ptr = &a[0]
    cdef Py_ssize_t length = len(a)
    # Branch on allow_nan once here instead of inside the hot loop.
    if allow_nan:
        return _isfinite_allow_nan(a_ptr, length)
    else:
        return _isfinite_disable_nan(a_ptr, length)


cdef inline FiniteStatus _isfinite_allow_nan(floating* a_ptr,
                                             Py_ssize_t length) noexcept nogil:
    # NaNs are tolerated here: only infinities end the scan early.
    cdef Py_ssize_t i
    cdef floating v
    for i in range(length):
        v = a_ptr[i]
        if isinf(v):
            return FiniteStatus.has_infinite
    return FiniteStatus.all_finite


cdef inline FiniteStatus _isfinite_disable_nan(floating* a_ptr,
                                               Py_ssize_t length) noexcept nogil:
    # Report the first non-finite value found, distinguishing NaN from inf.
    cdef Py_ssize_t i
    cdef floating v
    for i in range(length):
        v = a_ptr[i]
        if isnan(v):
            return FiniteStatus.has_nan
        elif isinf(v):
            return FiniteStatus.has_infinite
    return FiniteStatus.all_finite

View File

@@ -0,0 +1,181 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from contextlib import suppress
import numpy as np
from scipy import sparse as sp
from ._missing import is_scalar_nan
from ._param_validation import validate_params
from .fixes import _object_dtype_isnan
def _get_dense_mask(X, value_to_mask):
    """Return the boolean mask ``X == value_to_mask`` for a dense array."""
    # pandas.NA needs pandas' own missing-value detection. AttributeError is
    # also suppressed because older versions of pandas do not have `NA`.
    with suppress(ImportError, AttributeError):
        import pandas

        if value_to_mask is pandas.NA:
            return pandas.isna(X)

    if not is_scalar_nan(value_to_mask):
        return X == value_to_mask

    dtype_kind = X.dtype.kind
    if dtype_kind == "f":
        return np.isnan(X)
    if dtype_kind in ("i", "u"):
        # can't have NaNs in integer array.
        return np.zeros(X.shape, dtype=bool)
    # np.isnan does not work on object dtypes.
    return _object_dtype_isnan(X)
def _get_mask(X, value_to_mask):
    """Compute the boolean mask X == value_to_mask.

    Parameters
    ----------
    X : {ndarray, sparse matrix} of shape (n_samples, n_features)
        Input data, where ``n_samples`` is the number of samples and
        ``n_features`` is the number of features.

    value_to_mask : {int, float}
        The value which is to be masked in X.

    Returns
    -------
    X_mask : {ndarray, sparse matrix} of shape (n_samples, n_features)
        Missing mask.
    """
    if not sp.issparse(X):
        # Dense inputs (and anything that is not sparse) take the direct path.
        return _get_dense_mask(X, value_to_mask)

    # Sparse input: build the mask over the stored values only, then re-wrap
    # it with the same sparsity structure (indices/indptr are copied).
    data_mask = _get_dense_mask(X.data, value_to_mask)
    sparse_constructor = sp.csr_matrix if X.format == "csr" else sp.csc_matrix
    return sparse_constructor(
        (data_mask, X.indices.copy(), X.indptr.copy()), shape=X.shape, dtype=bool
    )
@validate_params(
    {
        "X": ["array-like", "sparse matrix"],
        "mask": ["array-like"],
    },
    prefer_skip_nested_validation=True,
)
def safe_mask(X, mask):
    """Return a mask which is safe to use on X.

    Parameters
    ----------
    X : {array-like, sparse matrix}
        Data on which to apply mask.

    mask : array-like
        Mask to be used on X.

    Returns
    -------
    mask : ndarray
        Array that is safe to use on X.

    Examples
    --------
    >>> from sklearn.utils import safe_mask
    >>> from scipy.sparse import csr_matrix
    >>> data = csr_matrix([[1], [2], [3], [4], [5]])
    >>> condition = [False, True, True, False, True]
    >>> mask = safe_mask(data, condition)
    >>> data[mask].toarray()
    array([[2],
           [3],
           [5]])
    """
    mask = np.asarray(mask)
    if np.issubdtype(mask.dtype, np.signedinteger):
        # Already an integer index array: safe as-is.
        return mask
    if hasattr(X, "toarray"):
        # For sparse matrices, convert the boolean mask into the integer
        # positions of its True entries.
        mask = np.arange(mask.shape[0])[mask]
    return mask
def axis0_safe_slice(X, mask, len_mask):
    """Return a mask which is safer to use on X than safe_mask.

    This mask is safer than safe_mask since it returns an
    empty array, when a sparse matrix is sliced with a boolean mask
    with all False, instead of raising an unhelpful error in older
    versions of SciPy.

    See: https://github.com/scipy/scipy/issues/5361

    Also note that we can avoid doing the dot product by checking if
    the len_mask is not zero in _huber_loss_and_gradient but this
    is not going to be the bottleneck, since the number of outliers
    and non_outliers are typically non-zero and it makes the code
    tougher to follow.

    Parameters
    ----------
    X : {array-like, sparse matrix}
        Data on which to apply mask.

    mask : ndarray
        Mask to be used on X.

    len_mask : int
        The length of the mask.

    Returns
    -------
    mask : ndarray
        Array that is safe to use on X.
    """
    if len_mask == 0:
        # All-False mask: return an empty (0, n_features) array directly.
        return np.zeros(shape=(0, X.shape[1]))
    return X[safe_mask(X, mask), :]
def indices_to_mask(indices, mask_length):
    """Convert list of indices to boolean mask.

    Parameters
    ----------
    indices : list-like
        List of integers treated as indices. May be empty, in which case an
        all-False mask is returned.

    mask_length : int
        Length of boolean mask to be generated.
        This parameter must be greater than max(indices).

    Returns
    -------
    mask : 1d boolean nd-array
        Boolean array that is True where indices are present, else False.

    Examples
    --------
    >>> from sklearn.utils._mask import indices_to_mask
    >>> indices = [1, 2 , 3, 4]
    >>> indices_to_mask(indices, 5)
    array([False,  True,  True,  True,  True])
    """
    indices_arr = np.asarray(indices)
    # np.max raises an opaque "zero-size array" ValueError on empty input;
    # only validate the bound when there is something to validate, so that
    # empty `indices` produce an all-False mask instead of crashing.
    if indices_arr.size and mask_length <= np.max(indices_arr):
        raise ValueError("mask_length must be greater than max(indices)")
    mask = np.zeros(mask_length, dtype=bool)
    # Use the original `indices` object so fancy-indexing semantics are
    # identical to the previous behavior for non-empty input.
    mask[indices] = True
    return mask

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,68 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import math
import numbers
from contextlib import suppress
def is_scalar_nan(x):
    """Test if x is NaN.

    This function is meant to overcome the issue that np.isnan does not allow
    non-numerical types as input, and that np.nan is not float('nan').

    Parameters
    ----------
    x : any type
        Any scalar value.

    Returns
    -------
    bool
        Returns true if x is NaN, and false otherwise.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils._missing import is_scalar_nan
    >>> is_scalar_nan(np.nan)
    True
    >>> is_scalar_nan(float("nan"))
    True
    >>> is_scalar_nan(None)
    False
    >>> is_scalar_nan("")
    False
    >>> is_scalar_nan([np.nan])
    False
    """
    # Integral values (including numpy integer scalars) can never be NaN.
    if isinstance(x, numbers.Integral):
        return False
    # math.isnan only accepts real numbers; reject everything else first so
    # that strings, None and containers return False instead of raising.
    if not isinstance(x, numbers.Real):
        return False
    return math.isnan(x)
def is_pandas_na(x):
    """Test if x is pandas.NA.

    We intentionally do not use this function to return `True` for `pd.NA` in
    `is_scalar_nan`, because estimators that support `pd.NA` are the exception
    rather than the rule at the moment. When `pd.NA` is more universally
    supported, we may reconsider this decision.

    Parameters
    ----------
    x : any type

    Returns
    -------
    boolean
    """
    # pandas is an optional dependency: without it nothing can be pd.NA.
    try:
        from pandas import NA
    except ImportError:
        return False
    # pd.NA is a singleton, so an identity check suffices.
    return x is NA

View File

@@ -0,0 +1,419 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
from ..base import BaseEstimator, ClassifierMixin
from ..utils._metadata_requests import RequestMethod
from .metaestimators import available_if
from .validation import (
_check_sample_weight,
_num_samples,
check_array,
check_is_fitted,
check_random_state,
)
class ArraySlicingWrapper:
    """Minimal stand-in for a pandas ``.iloc`` indexer used by MockDataFrame.

    Parameters
    ----------
    array : array-like
        The underlying data to slice.
    """
    def __init__(self, array):
        self.array = array
    def __getitem__(self, aslice):
        # Mimic pandas ``.iloc``: slicing returns a new MockDataFrame
        # wrapping the sliced underlying array.
        return MockDataFrame(self.array[aslice])
class MockDataFrame:
    """Minimal pandas-DataFrame-like wrapper around an ndarray for tests.

    Exposes ``shape``, ``len``, ``values``, ``iloc`` and ``take`` but does
    not support direct ``[]`` indexing, mirroring a real DataFrame.

    Parameters
    ----------
    array : ndarray
        The underlying data.
    """
    # have shape and length but don't support indexing.
    def __init__(self, array):
        self.array = array
        self.values = array
        self.shape = array.shape
        self.ndim = array.ndim
        # ugly hack to make iloc work.
        self.iloc = ArraySlicingWrapper(array)
    def __len__(self):
        return len(self.array)
    def __array__(self, dtype=None):
        # Pandas data frames also are array-like: we want to make sure that
        # input validation in cross-validation does not try to call that
        # method.
        return self.array
    def __eq__(self, other):
        # Elementwise comparison, wrapped like pandas would return a frame.
        return MockDataFrame(self.array == other.array)
    def __ne__(self, other):
        # NOTE(review): ``self == other`` yields a MockDataFrame, so ``not``
        # falls back to ``__len__`` and this returns a single bool rather than
        # an elementwise result — presumably intentional for the tests; verify
        # before relying on elementwise ``!=``.
        return not self == other
    def take(self, indices, axis=0):
        # Positional selection along an axis, like DataFrame.take.
        return MockDataFrame(self.array.take(indices, axis=axis))
class CheckingClassifier(ClassifierMixin, BaseEstimator):
    """Dummy classifier to test pipelining and meta-estimators.

    Checks some property of `X` and `y` in fit / predict.
    This allows testing whether pipelines / cross-validation or metaestimators
    changed the input.

    Can also be used to check if `fit_params` are passed correctly, and
    to force a certain score to be returned.

    Parameters
    ----------
    check_y, check_X : callable, default=None
        The callable used to validate `X` and `y`. These callable should return
        a bool where `False` will trigger an `AssertionError`. If `None`, the
        data is not validated. Default is `None`.

    check_y_params, check_X_params : dict, default=None
        The optional parameters to pass to `check_X` and `check_y`. If `None`,
        then no parameters are passed in.

    methods_to_check : "all" or list of str, default="all"
        The methods in which the checks should be applied. By default,
        all checks will be done on all methods (`fit`, `predict`,
        `predict_proba`, `decision_function` and `score`).

    foo_param : int, default=0
        A `foo` param. When `foo > 1`, the output of :meth:`score` will be 1
        otherwise it is 0.

    expected_sample_weight : bool, default=False
        Whether to check if a valid `sample_weight` was passed to `fit`.

    expected_fit_params : list of str, default=None
        A list of the expected parameters given when calling `fit`.

    Attributes
    ----------
    classes_ : int
        The classes seen during `fit`.

    n_features_in_ : int
        The number of features seen during `fit`.

    Examples
    --------
    >>> from sklearn.utils._mocking import CheckingClassifier

    This helper allow to assert to specificities regarding `X` or `y`. In this
    case we expect `check_X` or `check_y` to return a boolean.

    >>> from sklearn.datasets import load_iris
    >>> X, y = load_iris(return_X_y=True)
    >>> clf = CheckingClassifier(check_X=lambda x: x.shape == (150, 4))
    >>> clf.fit(X, y)
    CheckingClassifier(...)

    We can also provide a check which might raise an error. In this case, we
    expect `check_X` to return `X` and `check_y` to return `y`.

    >>> from sklearn.utils import check_array
    >>> clf = CheckingClassifier(check_X=check_array)
    >>> clf.fit(X, y)
    CheckingClassifier(...)
    """
    def __init__(
        self,
        *,
        check_y=None,
        check_y_params=None,
        check_X=None,
        check_X_params=None,
        methods_to_check="all",
        foo_param=0,
        expected_sample_weight=None,
        expected_fit_params=None,
        random_state=None,
    ):
        self.check_y = check_y
        self.check_y_params = check_y_params
        self.check_X = check_X
        self.check_X_params = check_X_params
        self.methods_to_check = methods_to_check
        self.foo_param = foo_param
        self.expected_sample_weight = expected_sample_weight
        self.expected_fit_params = expected_fit_params
        self.random_state = random_state
    def _check_X_y(self, X, y=None, should_be_fitted=True):
        """Validate X and y and make extra check.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data set.
            `X` is checked only if `check_X` is not `None` (default is None).
        y : array-like of shape (n_samples), default=None
            The corresponding target, by default `None`.
            `y` is checked only if `check_y` is not `None` (default is None).
        should_be_fitted : bool, default=True
            Whether or not the classifier should be already fitted.
            By default True.

        Returns
        -------
        X, y
        """
        if should_be_fitted:
            check_is_fitted(self)
        if self.check_X is not None:
            params = {} if self.check_X_params is None else self.check_X_params
            checked_X = self.check_X(X, **params)
            # A boolean result is treated as an assertion-style check; any
            # other result replaces X (e.g. when check_X is check_array).
            if isinstance(checked_X, (bool, np.bool_)):
                assert checked_X
            else:
                X = checked_X
        if y is not None and self.check_y is not None:
            params = {} if self.check_y_params is None else self.check_y_params
            checked_y = self.check_y(y, **params)
            # Same convention as for X above.
            if isinstance(checked_y, (bool, np.bool_)):
                assert checked_y
            else:
                y = checked_y
        return X, y
    def fit(self, X, y, sample_weight=None, **fit_params):
        """Fit classifier.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training vector, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples, n_outputs) or (n_samples,), \
                default=None
            Target relative to X for classification or regression;
            None for unsupervised learning.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of the estimator

        Returns
        -------
        self
        """
        assert _num_samples(X) == _num_samples(y)
        if self.methods_to_check == "all" or "fit" in self.methods_to_check:
            X, y = self._check_X_y(X, y, should_be_fitted=False)
        self.n_features_in_ = np.shape(X)[1]
        self.classes_ = np.unique(check_array(y, ensure_2d=False, allow_nd=True))
        if self.expected_fit_params:
            missing = set(self.expected_fit_params) - set(fit_params)
            if missing:
                raise AssertionError(
                    f"Expected fit parameter(s) {list(missing)} not seen."
                )
            # Every expected fit param must be sample-aligned with X.
            for key, value in fit_params.items():
                if _num_samples(value) != _num_samples(X):
                    raise AssertionError(
                        f"Fit parameter {key} has length {_num_samples(value)}"
                        f"; expected {_num_samples(X)}."
                    )
        if self.expected_sample_weight:
            if sample_weight is None:
                raise AssertionError("Expected sample_weight to be passed")
            _check_sample_weight(sample_weight, X)
        return self
    def predict(self, X):
        """Predict the first class seen in `classes_`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input data.

        Returns
        -------
        preds : ndarray of shape (n_samples,)
            Predictions of the first class seen in `classes_`.
        """
        if self.methods_to_check == "all" or "predict" in self.methods_to_check:
            X, y = self._check_X_y(X)
        # Random (but random_state-reproducible) labels from the fitted classes.
        rng = check_random_state(self.random_state)
        return rng.choice(self.classes_, size=_num_samples(X))
    def predict_proba(self, X):
        """Predict probabilities for each class.

        Here, the dummy classifier will provide a probability of 1 for the
        first class of `classes_` and 0 otherwise.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input data.

        Returns
        -------
        proba : ndarray of shape (n_samples, n_classes)
            The probabilities for each sample and class.
        """
        if self.methods_to_check == "all" or "predict_proba" in self.methods_to_check:
            X, y = self._check_X_y(X)
        rng = check_random_state(self.random_state)
        proba = rng.randn(_num_samples(X), len(self.classes_))
        # Make the random scores non-negative and row-normalize so each row
        # is a valid probability distribution.
        proba = np.abs(proba, out=proba)
        proba /= np.sum(proba, axis=1)[:, np.newaxis]
        return proba
    def decision_function(self, X):
        """Confidence score.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input data.

        Returns
        -------
        decision : ndarray of shape (n_samples,) if n_classes == 2\
                else (n_samples, n_classes)
            Confidence score.
        """
        if (
            self.methods_to_check == "all"
            or "decision_function" in self.methods_to_check
        ):
            X, y = self._check_X_y(X)
        rng = check_random_state(self.random_state)
        if len(self.classes_) == 2:
            # for binary classifier, the confidence score is related to
            # classes_[1] and therefore should be null.
            return rng.randn(_num_samples(X))
        else:
            return rng.randn(_num_samples(X), len(self.classes_))
    def score(self, X=None, Y=None):
        """Fake score.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        Y : array-like of shape (n_samples, n_output) or (n_samples,)
            Target relative to X for classification or regression;
            None for unsupervised learning.

        Returns
        -------
        score : float
            Either 0 or 1 depending of `foo_param` (i.e. `foo_param > 1 =>
            score=1` otherwise `score=0`).
        """
        if self.methods_to_check == "all" or "score" in self.methods_to_check:
            self._check_X_y(X, Y)
        if self.foo_param > 1:
            score = 1.0
        else:
            score = 0.0
        return score
    def __sklearn_tags__(self):
        # Mark the estimator so the common test suite skips it and input
        # validation stays permissive (no 2d enforcement on X).
        tags = super().__sklearn_tags__()
        tags._skip_test = True
        tags.input_tags.two_d_array = False
        tags.target_tags.one_d_labels = True
        return tags
# Deactivate key validation for CheckingClassifier because we want to be able to
# call fit with arbitrary fit_params and record them. Without this change, we
# would get an error because those arbitrary params are not expected.
# NOTE: RequestMethod with validate_keys=False accepts any metadata key.
CheckingClassifier.set_fit_request = RequestMethod(  # type: ignore[assignment,method-assign]
    name="fit", keys=[], validate_keys=False
)
class NoSampleWeightWrapper(BaseEstimator):
    """Wrap an estimator so that `sample_weight` is not exposed in `fit`.

    Useful in tests that must exercise the code path for estimators without
    sample-weight support.

    Parameters
    ----------
    est : estimator, default=None
        The estimator to wrap.
    """

    def __init__(self, est=None):
        self.est = est

    def fit(self, X, y):
        """Fit the wrapped estimator (no `sample_weight` accepted)."""
        return self.est.fit(X, y)

    def predict(self, X):
        """Delegate prediction to the wrapped estimator."""
        return self.est.predict(X)

    def predict_proba(self, X):
        """Delegate probability prediction to the wrapped estimator."""
        return self.est.predict_proba(X)

    def __sklearn_tags__(self):
        """Skip the common estimator checks for this test helper."""
        tags = super().__sklearn_tags__()
        tags._skip_test = True
        return tags
def _check_response(method):
def check(self):
return self.response_methods is not None and method in self.response_methods
return check
class _MockEstimatorOnOffPrediction(BaseEstimator):
    """Estimator for which we can turn on/off the prediction methods.

    Parameters
    ----------
    response_methods: list of \
            {"predict", "predict_proba", "decision_function"}, default=None
        List containing the response implemented by the estimator. When the
        response is in the list, it will return the name of the response method
        when called. Otherwise, an `AttributeError` is raised. It allows to
        use `getattr` as any conventional estimator. By default, no response
        methods are mocked.
    """
    def __init__(self, response_methods=None):
        self.response_methods = response_methods
    def fit(self, X, y):
        # Only record the classes; predictions are mocked strings below.
        self.classes_ = np.unique(y)
        return self
    # Each response method is only "available" (via available_if) when listed
    # in response_methods, and returns its own name so tests can assert which
    # response was used.
    @available_if(_check_response("predict"))
    def predict(self, X):
        return "predict"
    @available_if(_check_response("predict_proba"))
    def predict_proba(self, X):
        return "predict_proba"
    @available_if(_check_response("decision_function"))
    def decision_function(self, X):
        return "decision_function"

View File

@@ -0,0 +1,33 @@
# Helpers to safely access OpenMP routines
#
# no-op implementations are provided for the case where OpenMP is not available.
#
# All calls to OpenMP routines should be cimported from this module.
# Verbatim C block: when OpenMP is available the real <omp.h> API is used;
# otherwise every lock/thread routine is replaced by a no-op macro so the
# same Cython code compiles either way.
cdef extern from *:
    """
    #ifdef _OPENMP
        #include <omp.h>
        #define SKLEARN_OPENMP_PARALLELISM_ENABLED 1
    #else
        #define SKLEARN_OPENMP_PARALLELISM_ENABLED 0
        #define omp_lock_t int
        #define omp_init_lock(l) (void)0
        #define omp_destroy_lock(l) (void)0
        #define omp_set_lock(l) (void)0
        #define omp_unset_lock(l) (void)0
        #define omp_get_thread_num() 0
        #define omp_get_max_threads() 1
    #endif
    """
    # Compile-time flag exposed to Cython/Python code.
    bint SKLEARN_OPENMP_PARALLELISM_ENABLED

    # Opaque lock type: the real omp_lock_t with OpenMP, a plain int otherwise.
    ctypedef struct omp_lock_t:
        pass

    void omp_init_lock(omp_lock_t*) noexcept nogil
    void omp_destroy_lock(omp_lock_t*) noexcept nogil
    void omp_set_lock(omp_lock_t*) noexcept nogil
    void omp_unset_lock(omp_lock_t*) noexcept nogil
    int omp_get_thread_num() noexcept nogil
    int omp_get_max_threads() noexcept nogil

View File

@@ -0,0 +1,77 @@
import os
from joblib import cpu_count
# Module level cache for cpu_count as we do not expect this to change during
# the lifecycle of a Python program. This dictionary is keyed by
# only_physical_cores.
_CPU_COUNTS = {}
def _openmp_parallelism_enabled():
    """Determine whether scikit-learn has been built with OpenMP support.

    It allows to retrieve at runtime the information gathered at compile time.
    """
    # SKLEARN_OPENMP_PARALLELISM_ENABLED is resolved at compile time and defined
    # in _openmp_helpers.pxd as a boolean. This function exposes it to Python.
    return SKLEARN_OPENMP_PARALLELISM_ENABLED
cpdef _openmp_effective_n_threads(n_threads=None, only_physical_cores=True):
    """Determine the effective number of threads to be used for OpenMP calls

    - For ``n_threads = None``,
      - if the ``OMP_NUM_THREADS`` environment variable is set, return
        ``openmp.omp_get_max_threads()``
      - otherwise, return the minimum between ``openmp.omp_get_max_threads()``
        and the number of cpus, taking cgroups quotas into account. Cgroups
        quotas can typically be set by tools such as Docker.
      The result of ``omp_get_max_threads`` can be influenced by environment
      variable ``OMP_NUM_THREADS`` or at runtime by ``omp_set_num_threads``.

    - For ``n_threads > 0``, return this as the maximal number of threads for
      parallel OpenMP calls.

    - For ``n_threads < 0``, return the maximal number of threads minus
      ``|n_threads + 1|``. In particular ``n_threads = -1`` will use as many
      threads as there are available cores on the machine.

    - Raise a ValueError for ``n_threads = 0``.

    Passing the `only_physical_cores=False` flag makes it possible to use extra
    threads for SMT/HyperThreading logical cores. It has been empirically
    observed that using as many threads as available SMT cores can slightly
    improve the performance in some cases, but can severely degrade
    performance other times. Therefore it is recommended to use
    `only_physical_cores=True` unless an empirical study has been conducted to
    assess the impact of SMT on a case-by-case basis (using various input data
    shapes, in particular small data shapes).

    If scikit-learn is built without OpenMP support, always return 1.
    """
    if n_threads == 0:
        raise ValueError("n_threads = 0 is invalid")
    if not SKLEARN_OPENMP_PARALLELISM_ENABLED:
        # OpenMP disabled at build-time => sequential mode
        return 1
    if os.getenv("OMP_NUM_THREADS"):
        # Fall back to user provided number of threads making it possible
        # to exceed the number of cpus.
        max_n_threads = omp_get_max_threads()
    else:
        # Cache the cpu_count result: it is expensive to compute and does not
        # change during the lifetime of the process.
        try:
            n_cpus = _CPU_COUNTS[only_physical_cores]
        except KeyError:
            n_cpus = cpu_count(only_physical_cores=only_physical_cores)
            _CPU_COUNTS[only_physical_cores] = n_cpus
        max_n_threads = min(omp_get_max_threads(), n_cpus)
    if n_threads is None:
        return max_n_threads
    elif n_threads < 0:
        # Negative values count back from the maximum: -1 => all cores.
        return max(1, max_n_threads + n_threads + 1)
    return n_threads

View File

@@ -0,0 +1,46 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
def check_matplotlib_support(caller_name):
    """Raise ImportError with detailed error message if mpl is not installed.

    Plot utilities like any of the Display's plotting functions should lazily import
    matplotlib and call this helper before any computation.

    Parameters
    ----------
    caller_name : str
        The name of the caller that requires matplotlib.
    """
    try:
        import matplotlib  # noqa: F401
    except ImportError as e:
        # Surface the caller's name so the user knows which feature needs mpl.
        message = (
            "{} requires matplotlib. You can install matplotlib with "
            "`pip install matplotlib`".format(caller_name)
        )
        raise ImportError(message) from e
def check_pandas_support(caller_name):
    """Raise ImportError with detailed error message if pandas is not installed.

    Plot utilities like :func:`fetch_openml` should lazily import
    pandas and call this helper before any computation.

    Parameters
    ----------
    caller_name : str
        The name of the caller that requires pandas.

    Returns
    -------
    pandas
        The pandas package.
    """
    try:
        import pandas
    except ImportError as e:
        raise ImportError("{} requires pandas.".format(caller_name)) from e
    # Returned outside the try block: only the import itself is guarded.
    return pandas

View File

@@ -0,0 +1,910 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import functools
import math
import operator
import re
from abc import ABC, abstractmethod
from collections.abc import Iterable
from inspect import signature
from numbers import Integral, Real
import numpy as np
from scipy.sparse import csr_matrix, issparse
from .._config import config_context, get_config
from .validation import _is_arraylike_not_scalar
class InvalidParameterError(ValueError, TypeError):
    """Custom exception raised when the parameter of a class/method/function
    does not have a valid type or value.

    Inherits from both ValueError and TypeError to keep backward compatibility
    with code catching either of them.
    """
def validate_parameter_constraints(parameter_constraints, params, caller_name):
    """Validate types and values of given parameters.

    Parameters
    ----------
    parameter_constraints : dict or {"no_validation"}
        If "no_validation", validation is skipped for this parameter.

        If a dict, it must be a dictionary `param_name: list of constraints`.
        A parameter is valid if it satisfies one of the constraints from the list.
        Constraints can be:
        - an Interval object, representing a continuous or discrete range of numbers
        - the string "array-like"
        - the string "sparse matrix"
        - the string "random_state"
        - callable
        - None, meaning that None is a valid value for the parameter
        - any type, meaning that any instance of this type is valid
        - an Options object, representing a set of elements of a given type
        - a StrOptions object, representing a set of strings
        - the string "boolean"
        - the string "verbose"
        - the string "cv_object"
        - the string "nan"
        - a MissingValues object representing markers for missing values
        - a HasMethods object, representing method(s) an object must have
        - a Hidden object, representing a constraint not meant to be exposed to the user

    params : dict
        A dictionary `param_name: param_value`. The parameters to validate against the
        constraints.

    caller_name : str
        The name of the estimator or function or method that called this function.

    Raises
    ------
    InvalidParameterError
        If a parameter satisfies none of its constraints.
    """
    for param_name, param_val in params.items():
        # We allow parameters to not have a constraint so that third party estimators
        # can inherit from sklearn estimators without having to necessarily use the
        # validation tools.
        if param_name not in parameter_constraints:
            continue
        constraints = parameter_constraints[param_name]
        if constraints == "no_validation":
            continue
        constraints = [make_constraint(constraint) for constraint in constraints]
        # for/else: the else branch runs only if no `break` occurred, i.e. no
        # constraint accepted the value.
        for constraint in constraints:
            if constraint.is_satisfied_by(param_val):
                # this constraint is satisfied, no need to check further.
                break
        else:
            # No constraint is satisfied, raise with an informative message.
            # Ignore constraints that we don't want to expose in the error message,
            # i.e. options that are for internal purpose or not officially supported.
            constraints = [
                constraint for constraint in constraints if not constraint.hidden
            ]
            if len(constraints) == 1:
                constraints_str = f"{constraints[0]}"
            else:
                constraints_str = (
                    f"{', '.join([str(c) for c in constraints[:-1]])} or"
                    f" {constraints[-1]}"
                )
            raise InvalidParameterError(
                f"The {param_name!r} parameter of {caller_name} must be"
                f" {constraints_str}. Got {param_val!r} instead."
            )
def make_constraint(constraint):
    """Convert the constraint into the appropriate Constraint object.

    Parameters
    ----------
    constraint : object
        The constraint to convert.

    Returns
    -------
    constraint : instance of _Constraint
        The converted constraint.
    """
    # String shorthands map one-to-one onto dedicated constraint classes.
    if isinstance(constraint, str):
        shorthand_classes = {
            "array-like": _ArrayLikes,
            "sparse matrix": _SparseMatrices,
            "random_state": _RandomStates,
            "boolean": _Booleans,
            "verbose": _VerboseHelper,
            "cv_object": _CVObjects,
            "nan": _NanConstraint,
        }
        if constraint in shorthand_classes:
            return shorthand_classes[constraint]()
    # Identity check on purpose: only the builtin `callable` itself means
    # "any callable is accepted".
    if constraint is callable:
        return _Callables()
    if constraint is None:
        return _NoneConstraint()
    if isinstance(constraint, type):
        return _InstancesOf(constraint)
    if isinstance(
        constraint, (Interval, StrOptions, Options, HasMethods, MissingValues)
    ):
        # Already a constraint object: pass through unchanged.
        return constraint
    if isinstance(constraint, Hidden):
        # Unwrap and mark so it is excluded from user-facing error messages.
        inner = make_constraint(constraint.constraint)
        inner.hidden = True
        return inner
    if isinstance(constraint, float) and np.isnan(constraint):
        return _NanConstraint()
    raise ValueError(f"Unknown constraint type: {constraint}")
def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
    """Decorator to validate types and values of functions and methods.

    Parameters
    ----------
    parameter_constraints : dict
        A dictionary `param_name: list of constraints`. See the docstring of
        `validate_parameter_constraints` for a description of the accepted constraints.

        Note that the *args and **kwargs parameters are not validated and must not be
        present in the parameter_constraints dictionary.

    prefer_skip_nested_validation : bool
        If True, the validation of parameters of inner estimators or functions
        called by the decorated function will be skipped.

        This is useful to avoid validating many times the parameters passed by the
        user from the public facing API. It's also useful to avoid validating
        parameters that we pass internally to inner functions that are guaranteed to
        be valid by the test suite.

        It should be set to True for most functions, except for those that receive
        non-validated objects as parameters or that are just wrappers around classes
        because they only perform a partial validation.

    Returns
    -------
    decorated_function : function or method
        The decorated function.
    """
    def decorator(func):
        # The dict of parameter constraints is set as an attribute of the function
        # to make it possible to dynamically introspect the constraints for
        # automatic testing.
        setattr(func, "_skl_parameter_constraints", parameter_constraints)
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Global escape hatch: skip all validation when requested via config.
            global_skip_validation = get_config()["skip_parameter_validation"]
            if global_skip_validation:
                return func(*args, **kwargs)
            func_sig = signature(func)
            # Map *args/**kwargs to the function signature
            params = func_sig.bind(*args, **kwargs)
            params.apply_defaults()
            # ignore self/cls and positional/keyword markers
            to_ignore = [
                p.name
                for p in func_sig.parameters.values()
                if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
            ]
            to_ignore += ["self", "cls"]
            params = {k: v for k, v in params.arguments.items() if k not in to_ignore}
            validate_parameter_constraints(
                parameter_constraints, params, caller_name=func.__qualname__
            )
            try:
                # Optionally disable nested validation for inner calls while
                # this (already validated) call runs.
                with config_context(
                    skip_parameter_validation=(
                        prefer_skip_nested_validation or global_skip_validation
                    )
                ):
                    return func(*args, **kwargs)
            except InvalidParameterError as e:
                # When the function is just a wrapper around an estimator, we allow
                # the function to delegate validation to the estimator, but we replace
                # the name of the estimator by the name of the function in the error
                # message to avoid confusion.
                msg = re.sub(
                    r"parameter of \w+ must be",
                    f"parameter of {func.__qualname__} must be",
                    str(e),
                )
                raise InvalidParameterError(msg) from e
        return wrapper
    return decorator
class RealNotInt(Real):
    """A type that represents reals that are not instances of int.

    Behaves like float, but also works with values extracted from numpy arrays.
    isinstance(1, RealNotInt) -> False
    isinstance(1.0, RealNotInt) -> True
    """
# Register float as a virtual subclass: a float is a Real but not an int.
RealNotInt.register(float)
def _type_name(t):
"""Convert type into human readable string."""
module = t.__module__
qualname = t.__qualname__
if module == "builtins":
return qualname
elif t == Real:
return "float"
elif t == Integral:
return "int"
return f"{module}.{qualname}"
class _Constraint(ABC):
    """Base class for the constraint objects."""
    def __init__(self):
        # When True, the constraint is excluded from user-facing error
        # messages (see validate_parameter_constraints).
        self.hidden = False
    @abstractmethod
    def is_satisfied_by(self, val):
        """Whether or not a value satisfies the constraint.

        Parameters
        ----------
        val : object
            The value to check.

        Returns
        -------
        is_satisfied : bool
            Whether or not the constraint is satisfied by this value.
        """
    @abstractmethod
    def __str__(self):
        """A human readable representational string of the constraint."""
class _InstancesOf(_Constraint):
    """Constraint representing instances of a given type.

    Parameters
    ----------
    type : type
        The valid type.
    """

    def __init__(self, type):
        super().__init__()
        self.type = type

    def is_satisfied_by(self, val):
        """Accept any instance of the stored type."""
        return isinstance(val, self.type)

    def __str__(self):
        type_repr = _type_name(self.type)
        return f"an instance of {type_repr!r}"
class _NoneConstraint(_Constraint):
    """Constraint representing the None singleton."""

    def is_satisfied_by(self, val):
        # Identity check: only the None singleton itself qualifies.
        return val is None

    def __str__(self):
        return "None"
class _NanConstraint(_Constraint):
    """Constraint representing the indicator `np.nan`."""

    def is_satisfied_by(self, val):
        # Integral values can never be NaN; math.isnan requires a real number.
        if isinstance(val, Integral):
            return False
        return isinstance(val, Real) and math.isnan(val)

    def __str__(self):
        return "numpy.nan"
class _PandasNAConstraint(_Constraint):
    """Constraint representing the indicator `pd.NA`."""

    def is_satisfied_by(self, val):
        # Without pandas installed, nothing can be pd.NA.
        try:
            import pandas as pd
        except ImportError:
            return False
        return isinstance(val, type(pd.NA)) and pd.isna(val)

    def __str__(self):
        return "pandas.NA"
class Options(_Constraint):
    """Constraint representing a finite set of instances of a given type.

    Parameters
    ----------
    type : type

    options : set
        The set of valid scalars.

    deprecated : set or None, default=None
        A subset of the `options` to mark as deprecated in the string
        representation of the constraint.
    """

    def __init__(self, type, options, *, deprecated=None):
        super().__init__()
        self.type = type
        self.options = options
        self.deprecated = deprecated or set()
        # Deprecated entries must still be valid options.
        if self.deprecated - self.options:
            raise ValueError("The deprecated options must be a subset of the options.")

    def is_satisfied_by(self, val):
        """The value must have the right type and be one of the options."""
        return isinstance(val, self.type) and val in self.options

    def _mark_if_deprecated(self, option):
        """Add a deprecated mark to an option if needed."""
        if option in self.deprecated:
            return f"{option!r} (deprecated)"
        return f"{option!r}"

    def __str__(self):
        options_str = ", ".join(self._mark_if_deprecated(o) for o in self.options)
        return f"a {_type_name(self.type)} among {{{options_str}}}"
class StrOptions(Options):
    """Constraint representing a finite set of strings.

    Convenience subclass of `Options` fixing the type to `str`.

    Parameters
    ----------
    options : set of str
        The set of valid strings.

    deprecated : set of str or None, default=None
        A subset of the `options` to mark as deprecated in the string
        representation of the constraint.
    """
    def __init__(self, options, *, deprecated=None):
        super().__init__(type=str, options=options, deprecated=deprecated)
class Interval(_Constraint):
"""Constraint representing a typed interval.
Parameters
----------
type : {numbers.Integral, numbers.Real, RealNotInt}
The set of numbers in which to set the interval.
If RealNotInt, only reals that don't have the integer type
are allowed. For example 1.0 is allowed but 1 is not.
left : float or int or None
The left bound of the interval. None means left bound is -∞.
right : float, int or None
The right bound of the interval. None means right bound is +∞.
closed : {"left", "right", "both", "neither"}
Whether the interval is open or closed. Possible choices are:
- `"left"`: the interval is closed on the left and open on the right.
It is equivalent to the interval `[ left, right )`.
- `"right"`: the interval is closed on the right and open on the left.
It is equivalent to the interval `( left, right ]`.
- `"both"`: the interval is closed.
It is equivalent to the interval `[ left, right ]`.
- `"neither"`: the interval is open.
It is equivalent to the interval `( left, right )`.
Notes
-----
Setting a bound to `None` and setting the interval closed is valid. For instance,
strictly speaking, `Interval(Real, 0, None, closed="both")` corresponds to
`[0, +∞) U {+∞}`.
"""
def __init__(self, type, left, right, *, closed):
super().__init__()
self.type = type
self.left = left
self.right = right
self.closed = closed
self._check_params()
def _check_params(self):
if self.type not in (Integral, Real, RealNotInt):
raise ValueError(
"type must be either numbers.Integral, numbers.Real or RealNotInt."
f" Got {self.type} instead."
)
if self.closed not in ("left", "right", "both", "neither"):
raise ValueError(
"closed must be either 'left', 'right', 'both' or 'neither'. "
f"Got {self.closed} instead."
)
if self.type is Integral:
suffix = "for an interval over the integers."
if self.left is not None and not isinstance(self.left, Integral):
raise TypeError(f"Expecting left to be an int {suffix}")
if self.right is not None and not isinstance(self.right, Integral):
raise TypeError(f"Expecting right to be an int {suffix}")
if self.left is None and self.closed in ("left", "both"):
raise ValueError(
f"left can't be None when closed == {self.closed} {suffix}"
)
if self.right is None and self.closed in ("right", "both"):
raise ValueError(
f"right can't be None when closed == {self.closed} {suffix}"
)
else:
if self.left is not None and not isinstance(self.left, Real):
raise TypeError("Expecting left to be a real number.")
if self.right is not None and not isinstance(self.right, Real):
raise TypeError("Expecting right to be a real number.")
if self.right is not None and self.left is not None and self.right <= self.left:
raise ValueError(
f"right can't be less than left. Got left={self.left} and "
f"right={self.right}"
)
def __contains__(self, val):
if not isinstance(val, Integral) and np.isnan(val):
return False
left_cmp = operator.lt if self.closed in ("left", "both") else operator.le
right_cmp = operator.gt if self.closed in ("right", "both") else operator.ge
left = -np.inf if self.left is None else self.left
right = np.inf if self.right is None else self.right
if left_cmp(val, left):
return False
if right_cmp(val, right):
return False
return True
    def is_satisfied_by(self, val):
        # A value satisfies the constraint when it has the expected numeric
        # type AND falls within the bounds (delegated to __contains__).
        if not isinstance(val, self.type):
            return False
        return val in self
def __str__(self):
type_str = "an int" if self.type is Integral else "a float"
left_bracket = "[" if self.closed in ("left", "both") else "("
left_bound = "-inf" if self.left is None else self.left
right_bound = "inf" if self.right is None else self.right
right_bracket = "]" if self.closed in ("right", "both") else ")"
# better repr if the bounds were given as integers
if not self.type == Integral and isinstance(self.left, Real):
left_bound = float(left_bound)
if not self.type == Integral and isinstance(self.right, Real):
right_bound = float(right_bound)
return (
f"{type_str} in the range "
f"{left_bracket}{left_bound}, {right_bound}{right_bracket}"
)
class _ArrayLikes(_Constraint):
    """Constraint matching array-like objects (scalars excluded)."""
    def is_satisfied_by(self, val):
        # Delegate so the definition of "array-like" stays in one place.
        return _is_arraylike_not_scalar(val)
    def __str__(self):
        return "an array-like"
class _SparseMatrices(_Constraint):
    """Constraint matching scipy sparse matrices."""
    def __str__(self):
        return "a sparse matrix"
    def is_satisfied_by(self, val):
        return issparse(val)
class _Callables(_Constraint):
    """Constraint matching any callable object."""
    def __str__(self):
        return "a callable"
    def is_satisfied_by(self, val):
        return callable(val)
class _RandomStates(_Constraint):
    """Constraint accepting valid `random_state` specifications.

    Shorthand for
    [Interval(Integral, 0, 2**32 - 1, closed="both"), np.random.RandomState, None]
    """
    def __init__(self):
        super().__init__()
        # A seed in the valid 32-bit range, an existing RandomState, or None.
        self._constraints = [
            Interval(Integral, 0, 2**32 - 1, closed="both"),
            _InstancesOf(np.random.RandomState),
            _NoneConstraint(),
        ]
    def is_satisfied_by(self, val):
        for constraint in self._constraints:
            if constraint.is_satisfied_by(val):
                return True
        return False
    def __str__(self):
        *head, tail = self._constraints
        return f"{', '.join(str(c) for c in head)} or {tail}"
class _Booleans(_Constraint):
    """Constraint accepting boolean-like values.

    Shorthand for [bool, np.bool_].
    """
    def __init__(self):
        super().__init__()
        self._constraints = [
            _InstancesOf(bool),
            _InstancesOf(np.bool_),
        ]
    def is_satisfied_by(self, val):
        for constraint in self._constraints:
            if constraint.is_satisfied_by(val):
                return True
        return False
    def __str__(self):
        *head, tail = self._constraints
        return f"{', '.join(str(c) for c in head)} or {tail}"
class _VerboseHelper(_Constraint):
    """Constraint for `verbose` parameters.

    Shorthand for
    [Interval(Integral, 0, None, closed="left"), bool, numpy.bool_]
    """
    def __init__(self):
        super().__init__()
        # Non-negative integers or booleans are all valid verbosity levels.
        self._constraints = [
            Interval(Integral, 0, None, closed="left"),
            _InstancesOf(bool),
            _InstancesOf(np.bool_),
        ]
    def is_satisfied_by(self, val):
        for constraint in self._constraints:
            if constraint.is_satisfied_by(val):
                return True
        return False
    def __str__(self):
        *head, tail = self._constraints
        return f"{', '.join(str(c) for c in head)} or {tail}"
class MissingValues(_Constraint):
    """Helper constraint for the `missing_values` parameters.

    Accepts integers, real numbers, NaN and pandas.NA, plus strings and None
    unless `numeric_only` is True.

    Parameters
    ----------
    numeric_only : bool, default=False
        Whether to consider only numeric missing value markers.
    """
    def __init__(self, numeric_only=False):
        super().__init__()
        self.numeric_only = numeric_only
        # The Real interval rejects np.nan, which is matched by its own
        # dedicated constraint below.
        self._constraints = [
            _InstancesOf(Integral),
            Interval(Real, None, None, closed="both"),
            _NanConstraint(),
            _PandasNAConstraint(),
        ]
        if not self.numeric_only:
            self._constraints.extend([_InstancesOf(str), _NoneConstraint()])
    def is_satisfied_by(self, val):
        for constraint in self._constraints:
            if constraint.is_satisfied_by(val):
                return True
        return False
    def __str__(self):
        *head, tail = self._constraints
        return f"{', '.join(str(c) for c in head)} or {tail}"
class HasMethods(_Constraint):
    """Constraint representing objects that expose specific methods.

    Useful for parameters following a protocol, where we do not want to tie
    the value to a particular module or class.

    Parameters
    ----------
    methods : str or list of str
        The method(s) that the object is expected to expose.
    """
    @validate_params(
        {"methods": [str, list]},
        prefer_skip_nested_validation=True,
    )
    def __init__(self, methods):
        super().__init__()
        # Normalize a single method name to a one-element list.
        self.methods = [methods] if isinstance(methods, str) else methods
    def is_satisfied_by(self, val):
        # Every required attribute must exist and be callable.
        for method in self.methods:
            if not callable(getattr(val, method, None)):
                return False
        return True
    def __str__(self):
        if len(self.methods) == 1:
            return f"an object implementing {self.methods[0]!r}"
        head = ", ".join(repr(m) for m in self.methods[:-1])
        return f"an object implementing {head} and {self.methods[-1]!r}"
class _IterablesNotString(_Constraint):
    """Constraint matching iterables, with strings explicitly excluded."""
    def is_satisfied_by(self, val):
        # Strings are iterable but are rejected on purpose here.
        if isinstance(val, str):
            return False
        return isinstance(val, Iterable)
    def __str__(self):
        return "an iterable"
class _CVObjects(_Constraint):
    """Constraint accepting cross-validation specifications.

    Shorthand for
    [
        Interval(Integral, 2, None, closed="left"),
        HasMethods(["split", "get_n_splits"]),
        _IterablesNotString(),
        None,
    ]
    """
    def __init__(self):
        super().__init__()
        # A number of folds, a CV splitter object, an iterable of splits,
        # or None (use the default CV strategy).
        self._constraints = [
            Interval(Integral, 2, None, closed="left"),
            HasMethods(["split", "get_n_splits"]),
            _IterablesNotString(),
            _NoneConstraint(),
        ]
    def is_satisfied_by(self, val):
        for constraint in self._constraints:
            if constraint.is_satisfied_by(val):
                return True
        return False
    def __str__(self):
        *head, tail = self._constraints
        return f"{', '.join(str(c) for c in head)} or {tail}"
class Hidden:
    """Wrapper marking a constraint that is not meant to be exposed to the user.

    Parameters
    ----------
    constraint : str or _Constraint instance
        The constraint to be used internally.
    """
    def __init__(self, constraint):
        # Simply record the wrapped constraint for internal use.
        self.constraint = constraint
def generate_invalid_param_val(constraint):
    """Return a value that does not satisfy the constraint.

    Raises a NotImplementedError if there exists no invalid value for this
    constraint. This is only useful for testing purpose.

    Parameters
    ----------
    constraint : _Constraint instance
        The constraint to generate a value for.

    Returns
    -------
    val : object
        A value that does not satisfy the constraint.
    """
    if isinstance(constraint, StrOptions):
        # A string guaranteed to differ from every allowed option.
        return f"not {' or '.join(constraint.options)}"
    if isinstance(constraint, MissingValues):
        # Arrays are never valid missing-value markers.
        return np.array([1, 2, 3])
    if isinstance(constraint, _VerboseHelper):
        # Verbosity must be a non-negative int or a bool.
        return -1
    if isinstance(constraint, HasMethods):
        # An object exposing none of the required methods.
        return type("HasNotMethods", (), {})()
    if isinstance(constraint, _IterablesNotString):
        return "a string"
    if isinstance(constraint, _CVObjects):
        return "not a cv object"
    if isinstance(constraint, Interval) and constraint.type is Integral:
        # Step one unit outside a finite bound.
        for bound, delta in ((constraint.left, -1), (constraint.right, 1)):
            if bound is not None:
                return bound + delta
        # There's no integer outside (-inf, +inf)
        raise NotImplementedError
    if isinstance(constraint, Interval) and constraint.type in (Real, RealNotInt):
        # Step slightly outside a finite bound.
        for bound, delta in ((constraint.left, -1e-6), (constraint.right, 1e-6)):
            if bound is not None:
                return bound + delta
        # bounds are -inf, +inf
        if constraint.closed in ("right", "neither"):
            return -np.inf
        if constraint.closed in ("left", "neither"):
            return np.inf
        # interval is [-inf, +inf]
        return np.nan
    raise NotImplementedError
def generate_valid_param(constraint):
    """Return a value that does satisfy a constraint.

    This is only useful for testing purpose.

    Parameters
    ----------
    constraint : Constraint instance
        The constraint to generate a value for.

    Returns
    -------
    val : object
        A value that does satisfy the constraint.
    """
    if isinstance(constraint, _ArrayLikes):
        return np.array([1, 2, 3])
    if isinstance(constraint, _SparseMatrices):
        return csr_matrix([[0, 1], [1, 0]])
    if isinstance(constraint, _RandomStates):
        return np.random.RandomState(42)
    if isinstance(constraint, _Callables):
        return lambda x: x
    if isinstance(constraint, _NoneConstraint):
        return None
    if isinstance(constraint, _InstancesOf):
        if constraint.type is np.ndarray:
            # ndarray cannot be instantiated without arguments.
            return np.array([1, 2, 3])
        if constraint.type in (Integral, Real):
            # Integral and Real are abstract classes: use a concrete member.
            return 1
        return constraint.type()
    if isinstance(constraint, _Booleans):
        return True
    if isinstance(constraint, _VerboseHelper):
        return 1
    if isinstance(constraint, MissingValues) and constraint.numeric_only:
        return np.nan
    if isinstance(constraint, MissingValues) and not constraint.numeric_only:
        return "missing"
    if isinstance(constraint, HasMethods):
        # Build a throwaway type exposing exactly the required methods.
        return type(
            "ValidHasMethods", (), {m: lambda self: None for m in constraint.methods}
        )()
    if isinstance(constraint, _IterablesNotString):
        return [1, 2, 3]
    if isinstance(constraint, _CVObjects):
        return 5
    if isinstance(constraint, Options):  # includes StrOptions
        for option in constraint.options:
            return option
    if isinstance(constraint, Interval):
        left, right = constraint.left, constraint.right
        if left is None and right is None:
            return 0
        if left is None:
            return right - 1
        if right is None:
            return left + 1
        # Both bounds finite: midpoint for real intervals, left + 1 otherwise.
        if constraint.type is Real:
            return (left + right) / 2
        return left + 1
    raise ValueError(f"Unknown constraint type: {constraint}")

View File

@@ -0,0 +1,419 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from collections.abc import Mapping
import numpy as np
from . import check_consistent_length
from ._optional_dependencies import check_matplotlib_support
from ._response import _get_response_values_binary
from .fixes import parse_version
from .multiclass import type_of_target
from .validation import _check_pos_label_consistency, _num_samples
class _BinaryClassifierCurveDisplayMixin:
    """Mixin class to be used in Displays requiring a binary classifier.
    The aim of this class is to centralize some validations regarding the estimator and
    the target and gather the response of the estimator.
    """
    def _validate_plot_params(self, *, ax=None, name=None):
        """Check matplotlib availability, provide an Axes and resolve `name`."""
        check_matplotlib_support(f"{self.__class__.__name__}.plot")
        import matplotlib.pyplot as plt
        # Create a fresh figure/axes pair when the caller did not provide one.
        if ax is None:
            _, ax = plt.subplots()
        # Display classes are in process of changing from `estimator_name` to `name`.
        # Try old attr name: `estimator_name` first.
        if name is None:
            name = getattr(self, "estimator_name", getattr(self, "name", None))
        return ax, ax.figure, name
    @classmethod
    def _validate_and_get_response_values(
        cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None
    ):
        """Compute the binary response of `estimator` on `X`; resolve `name`."""
        check_matplotlib_support(f"{cls.__name__}.from_estimator")
        # Default the display name to the estimator's class name.
        name = estimator.__class__.__name__ if name is None else name
        y_pred, pos_label = _get_response_values_binary(
            estimator,
            X,
            response_method=response_method,
            pos_label=pos_label,
        )
        return y_pred, pos_label, name
    @classmethod
    def _validate_from_predictions_params(
        cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None
    ):
        """Validate precomputed predictions against a binary target."""
        check_matplotlib_support(f"{cls.__name__}.from_predictions")
        if type_of_target(y_true) != "binary":
            raise ValueError(
                f"The target y is not binary. Got {type_of_target(y_true)} type of"
                " target."
            )
        check_consistent_length(y_true, y_pred, sample_weight)
        pos_label = _check_pos_label_consistency(pos_label, y_true)
        name = name if name is not None else "Classifier"
        return pos_label, name
    @classmethod
    def _validate_from_cv_results_params(
        cls,
        cv_results,
        X,
        y,
        *,
        sample_weight,
        pos_label,
    ):
        """Validate `cv_results` (from `cross_validate`) together with `X`, `y`."""
        check_matplotlib_support(f"{cls.__name__}.from_cv_results")
        # The fitted estimators and the train/test indices are both required.
        required_keys = {"estimator", "indices"}
        if not all(key in cv_results for key in required_keys):
            raise ValueError(
                "`cv_results` does not contain one of the following required keys: "
                f"{required_keys}. Set explicitly the parameters "
                "`return_estimator=True` and `return_indices=True` to the function"
                "`cross_validate`."
            )
        # Sizes taken from the first fold; X must cover train + test samples.
        train_size, test_size = (
            len(cv_results["indices"]["train"][0]),
            len(cv_results["indices"]["test"][0]),
        )
        if _num_samples(X) != train_size + test_size:
            raise ValueError(
                "`X` does not contain the correct number of samples. "
                f"Expected {train_size + test_size}, got {_num_samples(X)}."
            )
        if type_of_target(y) != "binary":
            raise ValueError(
                f"The target `y` is not binary. Got {type_of_target(y)} type of target."
            )
        check_consistent_length(X, y, sample_weight)
        try:
            pos_label = _check_pos_label_consistency(pos_label, y)
        except ValueError as e:
            # Adapt error message
            raise ValueError(str(e).replace("y_true", "y"))
        return pos_label
    @staticmethod
    def _get_legend_label(curve_legend_metric, curve_name, legend_metric_name):
        """Helper to get legend label using `name` and `legend_metric`"""
        # Combine name and metric when both are available; fall back to
        # whichever is present, or None when neither is.
        if curve_legend_metric is not None and curve_name is not None:
            label = f"{curve_name} ({legend_metric_name} = {curve_legend_metric:0.2f})"
        elif curve_legend_metric is not None:
            label = f"{legend_metric_name} = {curve_legend_metric:0.2f}"
        elif curve_name is not None:
            label = curve_name
        else:
            label = None
        return label
    @staticmethod
    def _validate_curve_kwargs(
        n_curves,
        name,
        legend_metric,
        legend_metric_name,
        curve_kwargs,
        **kwargs,
    ):
        """Get validated line kwargs for each curve.
        Parameters
        ----------
        n_curves : int
            Number of curves.
        name : list of str or None
            Name for labeling legend entries.
        legend_metric : dict
            Dictionary with "mean" and "std" keys, or "metric" key of metric
            values for each curve. If None, "label" will not contain metric values.
        legend_metric_name : str
            Name of the summary value provided in `legend_metrics`.
        curve_kwargs : dict or list of dict or None
            Dictionary with keywords passed to the matplotlib's `plot` function
            to draw the individual curves. If a list is provided, the
            parameters are applied to the curves sequentially. If a single
            dictionary is provided, the same parameters are applied to all
            curves.
        **kwargs : dict
            Deprecated. Keyword arguments to be passed to matplotlib's `plot`.
        """
        # TODO(1.9): Remove deprecated **kwargs
        if curve_kwargs and kwargs:
            raise ValueError(
                "Cannot provide both `curve_kwargs` and `kwargs`. `**kwargs` is "
                "deprecated in 1.7 and will be removed in 1.9. Pass all matplotlib "
                "arguments to `curve_kwargs` as a dictionary."
            )
        if kwargs:
            warnings.warn(
                "`**kwargs` is deprecated and will be removed in 1.9. Pass all "
                "matplotlib arguments to `curve_kwargs` as a dictionary instead.",
                FutureWarning,
            )
            curve_kwargs = kwargs
        if isinstance(curve_kwargs, list) and len(curve_kwargs) != n_curves:
            raise ValueError(
                f"`curve_kwargs` must be None, a dictionary or a list of length "
                f"{n_curves}. Got: {curve_kwargs}."
            )
        # Ensure valid `name` and `curve_kwargs` combination.
        if (
            isinstance(name, list)
            and len(name) != 1
            and not isinstance(curve_kwargs, list)
        ):
            raise ValueError(
                "To avoid labeling individual curves that have the same appearance, "
                f"`curve_kwargs` should be a list of {n_curves} dictionaries. "
                "Alternatively, set `name` to `None` or a single string to label "
                "a single legend entry with mean ROC AUC score of all curves."
            )
        # Ensure `name` is of the correct length
        if isinstance(name, str):
            name = [name]
        if isinstance(name, list) and len(name) == 1:
            name = name * n_curves
        name = [None] * n_curves if name is None else name
        # Ensure `curve_kwargs` is of correct length
        if isinstance(curve_kwargs, Mapping):
            curve_kwargs = [curve_kwargs] * n_curves
        default_multi_curve_kwargs = {"alpha": 0.5, "linestyle": "--", "color": "blue"}
        if curve_kwargs is None:
            if n_curves > 1:
                curve_kwargs = [default_multi_curve_kwargs] * n_curves
            else:
                curve_kwargs = [{}]
        labels = []
        if "mean" in legend_metric:
            # Aggregate mode: one legend entry summarizing all curves.
            label_aggregate = _BinaryClassifierCurveDisplayMixin._get_legend_label(
                legend_metric["mean"], name[0], legend_metric_name
            )
            # Note: "std" always `None` when "mean" is `None` - no metric value added
            # to label in this case
            if legend_metric["std"] is not None:
                # Add the "+/- std" to the end (in brackets if name provided)
                if name[0] is not None:
                    label_aggregate = (
                        label_aggregate[:-1] + f" +/- {legend_metric['std']:0.2f})"
                    )
                else:
                    label_aggregate = (
                        label_aggregate + f" +/- {legend_metric['std']:0.2f}"
                    )
            # Add `label` for first curve only, set to `None` for remaining curves
            labels.extend([label_aggregate] + [None] * (n_curves - 1))
        else:
            # Per-curve mode: one label per (metric, name) pair.
            for curve_legend_metric, curve_name in zip(legend_metric["metric"], name):
                labels.append(
                    _BinaryClassifierCurveDisplayMixin._get_legend_label(
                        curve_legend_metric, curve_name, legend_metric_name
                    )
                )
        curve_kwargs_ = [
            _validate_style_kwargs({"label": label}, curve_kwargs[fold_idx])
            for fold_idx, label in enumerate(labels)
        ]
        return curve_kwargs_
def _validate_score_name(score_name, scoring, negate_score):
"""Validate the `score_name` parameter.
If `score_name` is provided, we just return it as-is.
If `score_name` is `None`, we use `Score` if `negate_score` is `False` and
`Negative score` otherwise.
If `score_name` is a string or a callable, we infer the name. We replace `_` by
spaces and capitalize the first letter. We remove `neg_` and replace it by
`"Negative"` if `negate_score` is `False` or just remove it otherwise.
"""
if score_name is not None:
return score_name
elif scoring is None:
return "Negative score" if negate_score else "Score"
else:
score_name = scoring.__name__ if callable(scoring) else scoring
if negate_score:
if score_name.startswith("neg_"):
score_name = score_name[4:]
else:
score_name = f"Negative {score_name}"
elif score_name.startswith("neg_"):
score_name = f"Negative {score_name[4:]}"
score_name = score_name.replace("_", " ")
return score_name.capitalize()
def _interval_max_min_ratio(data):
"""Compute the ratio between the largest and smallest inter-point distances.
A value larger than 5 typically indicates that the parameter range would
better be displayed with a log scale while a linear scale would be more
suitable otherwise.
"""
diff = np.diff(np.sort(data))
return diff.max() / diff.min()
def _validate_style_kwargs(default_style_kwargs, user_style_kwargs):
"""Create valid style kwargs by avoiding Matplotlib alias errors.
Matplotlib raises an error when, for example, 'color' and 'c', or 'linestyle' and
'ls', are specified together. To avoid this, we automatically keep only the one
specified by the user and raise an error if the user specifies both.
Parameters
----------
default_style_kwargs : dict
The Matplotlib style kwargs used by default in the scikit-learn display.
user_style_kwargs : dict
The user-defined Matplotlib style kwargs.
Returns
-------
valid_style_kwargs : dict
The validated style kwargs taking into account both default and user-defined
Matplotlib style kwargs.
"""
invalid_to_valid_kw = {
"ls": "linestyle",
"c": "color",
"ec": "edgecolor",
"fc": "facecolor",
"lw": "linewidth",
"mec": "markeredgecolor",
"mfcalt": "markerfacecoloralt",
"ms": "markersize",
"mew": "markeredgewidth",
"mfc": "markerfacecolor",
"aa": "antialiased",
"ds": "drawstyle",
"font": "fontproperties",
"family": "fontfamily",
"name": "fontname",
"size": "fontsize",
"stretch": "fontstretch",
"style": "fontstyle",
"variant": "fontvariant",
"weight": "fontweight",
"ha": "horizontalalignment",
"va": "verticalalignment",
"ma": "multialignment",
}
for invalid_key, valid_key in invalid_to_valid_kw.items():
if invalid_key in user_style_kwargs and valid_key in user_style_kwargs:
raise TypeError(
f"Got both {invalid_key} and {valid_key}, which are aliases of one "
"another"
)
valid_style_kwargs = default_style_kwargs.copy()
for key in user_style_kwargs.keys():
if key in invalid_to_valid_kw:
valid_style_kwargs[invalid_to_valid_kw[key]] = user_style_kwargs[key]
else:
valid_style_kwargs[key] = user_style_kwargs[key]
return valid_style_kwargs
def _despine(ax):
"""Remove the top and right spines of the plot.
Parameters
----------
ax : matplotlib.axes.Axes
The axes of the plot to despine.
"""
for s in ["top", "right"]:
ax.spines[s].set_visible(False)
for s in ["bottom", "left"]:
ax.spines[s].set_bounds(0, 1)
def _deprecate_estimator_name(estimator_name, name, version):
    """Deprecate `estimator_name` in favour of `name`.

    Returns `name` unless a legacy `estimator_name` was explicitly passed, in
    which case a FutureWarning is emitted and `estimator_name` is returned.
    Passing both raises a ValueError.
    """
    version = parse_version(version)
    # Deprecation window: removal happens two minor releases later.
    version_remove = f"{version.major}.{version.minor + 2}"
    if estimator_name == "deprecated":
        # Sentinel untouched: the caller only used `name`.
        return name
    if name:
        raise ValueError(
            "Cannot provide both `estimator_name` and `name`. `estimator_name` "
            f"is deprecated in {version} and will be removed in {version_remove}. "
            "Use `name` only."
        )
    warnings.warn(
        f"`estimator_name` is deprecated in {version} and will be removed in "
        f"{version_remove}. Use `name` instead.",
        FutureWarning,
    )
    return estimator_name
def _convert_to_list_leaving_none(param):
"""Convert parameters to a list, leaving `None` as is."""
if param is None:
return None
if isinstance(param, list):
return param
return [param]
def _check_param_lengths(required, optional, class_name):
"""Check required and optional parameters are of the same length."""
optional_provided = {}
for name, param in optional.items():
if isinstance(param, list):
optional_provided[name] = param
all_params = {**required, **optional_provided}
if len({len(param) for param in all_params.values()}) > 1:
param_keys = [key for key in all_params.keys()]
# Note: below code requires `len(param_keys) >= 2`, which is the case for all
# display classes
params_formatted = " and ".join([", ".join(param_keys[:-1]), param_keys[-1]])
or_plot = ""
if "'name' (or self.name)" in param_keys:
or_plot = " (or `plot`)"
lengths_formatted = ", ".join(
f"{key}: {len(value)}" for key, value in all_params.items()
)
raise ValueError(
f"{params_formatted} from `{class_name}` initialization{or_plot}, "
f"should all be lists of the same length. Got: {lengths_formatted}"
)

View File

@@ -0,0 +1,463 @@
"""This module contains the _EstimatorPrettyPrinter class used in
BaseEstimator.__repr__ for pretty-printing estimators"""
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
# 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018 Python Software Foundation;
# All Rights Reserved
# Authors: Fred L. Drake, Jr. <fdrake@acm.org> (built-in CPython pprint module)
# Nicolas Hug (scikit-learn specific changes)
# License: PSF License version 2 (see below)
# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
# --------------------------------------------
# 1. This LICENSE AGREEMENT is between the Python Software Foundation ("PSF"),
# and the Individual or Organization ("Licensee") accessing and otherwise
# using this software ("Python") in source or binary form and its associated
# documentation.
# 2. Subject to the terms and conditions of this License Agreement, PSF hereby
# grants Licensee a nonexclusive, royalty-free, world-wide license to
# reproduce, analyze, test, perform and/or display publicly, prepare
# derivative works, distribute, and otherwise use Python alone or in any
# derivative version, provided, however, that PSF's License Agreement and
# PSF's notice of copyright, i.e., "Copyright (c) 2001, 2002, 2003, 2004,
# 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016,
# 2017, 2018 Python Software Foundation; All Rights Reserved" are retained in
# Python alone or in any derivative version prepared by Licensee.
# 3. In the event Licensee prepares a derivative work that is based on or
# incorporates Python or any part thereof, and wants to make the derivative
# work available to others as provided herein, then Licensee hereby agrees to
# include in any such work a brief summary of the changes made to Python.
# 4. PSF is making Python available to Licensee on an "AS IS" basis. PSF MAKES
# NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT
# NOT LIMITATION, PSF MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF
# MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF
# PYTHON WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.
# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON FOR ANY
# INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF
# MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, OR ANY DERIVATIVE
# THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
# 6. This License Agreement will automatically terminate upon a material
# breach of its terms and conditions.
# 7. Nothing in this License Agreement shall be deemed to create any
# relationship of agency, partnership, or joint venture between PSF and
# Licensee. This License Agreement does not grant permission to use PSF
# trademarks or trade name in a trademark sense to endorse or promote products
# or services of Licensee, or any third party.
# 8. By copying, installing or otherwise using Python, Licensee agrees to be
# bound by the terms and conditions of this License Agreement.
# Brief summary of changes to original code:
# - "compact" parameter is supported for dicts, not just lists or tuples
# - estimators have a custom handler, they're not just treated as objects
# - long sequences (lists, tuples, dict items) with more than N elements are
# shortened using ellipsis (', ...') at the end.
import inspect
import pprint
from .._config import get_config
from ..base import BaseEstimator
from ._missing import is_scalar_nan
class KeyValTuple(tuple):
    """Tuple subclass used to render key-value pairs from dicts distinctly."""
    def __repr__(self):
        # Must stay a distinct function: _dispatch is keyed on the __repr__
        # function object, so inheriting tuple.__repr__ directly would make
        # the tuple entry in _dispatch apply to this class too.
        return tuple.__repr__(self)
class KeyValTupleParam(KeyValTuple):
    """Dummy class for correctly rendering key-value tuples from parameters."""
    # Marker subclass: estimator parameters are wrapped in this type so the
    # key-value printer can render them as `key=value` rather than the
    # `'key': value` form used for plain dict items.
    pass
def _changed_params(estimator):
    """Return dict (param_name: value) of parameters that were given to
    estimator with non-default values."""
    params = estimator.get_params(deep=False)
    # A decorated __init__ keeps the pristine signature under `deprecated_original`.
    init_func = getattr(estimator.__init__, "deprecated_original", estimator.__init__)
    defaults = {
        pname: parameter.default
        for pname, parameter in inspect.signature(init_func).parameters.items()
    }
    def _differs_from_default(k, v):
        # Parameters absorbed by **kwargs have no recorded default: keep them.
        if k not in defaults:
            return True
        if defaults[k] == inspect._empty:  # k has no default value
            return True
        # Cheap class comparison first, to avoid repr() on nested estimators.
        if isinstance(v, BaseEstimator) and v.__class__ != defaults[k].__class__:
            return True
        # Fall back to comparing reprs (may be expensive); two NaNs compare equal.
        if repr(v) != repr(defaults[k]) and not (
            is_scalar_nan(defaults[k]) and is_scalar_nan(v)
        ):
            return True
        return False
    return {k: v for k, v in params.items() if _differs_from_default(k, v)}
class _EstimatorPrettyPrinter(pprint.PrettyPrinter):
"""Pretty Printer class for estimator objects.
This extends the pprint.PrettyPrinter class, because:
- we need estimators to be printed with their parameters, e.g.
Estimator(param1=value1, ...) which is not supported by default.
- the 'compact' parameter of PrettyPrinter is ignored for dicts, which
may lead to very long representations that we want to avoid.
Quick overview of pprint.PrettyPrinter (see also
https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers):
- the entry point is the _format() method which calls format() (overridden
here)
- format() directly calls _safe_repr() for a first try at rendering the
object
- _safe_repr formats the whole object recursively, only calling itself,
not caring about line length or anything
- back to _format(), if the output string is too long, _format() then calls
the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on
the type of the object. This where the line length and the compact
parameters are taken into account.
- those _pprint_TYPE() methods will internally use the format() method for
rendering the nested objects of an object (e.g. the elements of a list)
In the end, everything has to be implemented twice: in _safe_repr and in
the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not
straightforward to extend (especially when we want a compact output), so
the code is a bit convoluted.
This class overrides:
- format() to support the changed_only parameter
- _safe_repr to support printing of estimators (for when they fit on a
single line)
- _format_dict_items so that dict are correctly 'compacted'
- _format_items so that ellipsis is used on long lists and tuples
When estimators cannot be printed on a single line, the builtin _format()
will call _pprint_estimator() because it was registered to do so (see
_dispatch[BaseEstimator.__repr__] = _pprint_estimator).
both _format_dict_items() and _pprint_estimator() use the
_format_params_or_dict_items() method that will format parameters and
key-value pairs respecting the compact parameter. This method needs another
subroutine _pprint_key_val_tuple() used when a parameter or a key-value
pair is too long to fit on a single line. This subroutine is called in
_format() and is registered as well in the _dispatch dict (just like
_pprint_estimator). We had to create the two classes KeyValTuple and
KeyValTupleParam for this.
"""
def __init__(
self,
indent=1,
width=80,
depth=None,
stream=None,
*,
compact=False,
indent_at_name=True,
n_max_elements_to_show=None,
):
super().__init__(indent, width, depth, stream, compact=compact)
self._indent_at_name = indent_at_name
if self._indent_at_name:
self._indent_per_level = 1 # ignore indent param
self._changed_only = get_config()["print_changed_only"]
# Max number of elements in a list, dict, tuple until we start using
# ellipsis. This also affects the number of arguments of an estimators
# (they are treated as dicts)
self.n_max_elements_to_show = n_max_elements_to_show
def format(self, object, context, maxlevels, level):
return _safe_repr(
object, context, maxlevels, level, changed_only=self._changed_only
)
def _pprint_estimator(self, object, stream, indent, allowance, context, level):
stream.write(object.__class__.__name__ + "(")
if self._indent_at_name:
indent += len(object.__class__.__name__)
if self._changed_only:
params = _changed_params(object)
else:
params = object.get_params(deep=False)
self._format_params(
sorted(params.items()), stream, indent, allowance + 1, context, level
)
stream.write(")")
def _format_dict_items(self, items, stream, indent, allowance, context, level):
return self._format_params_or_dict_items(
items, stream, indent, allowance, context, level, is_dict=True
)
def _format_params(self, items, stream, indent, allowance, context, level):
return self._format_params_or_dict_items(
items, stream, indent, allowance, context, level, is_dict=False
)
def _format_params_or_dict_items(
self, object, stream, indent, allowance, context, level, is_dict
):
"""Format dict items or parameters respecting the compact=True
parameter. For some reason, the builtin rendering of dict items doesn't
respect compact=True and will use one line per key-value if all cannot
fit in a single line.
Dict items will be rendered as <'key': value> while params will be
rendered as <key=value>. The implementation is mostly copy/pasting from
the builtin _format_items().
This also adds ellipsis if the number of items is greater than
self.n_max_elements_to_show.
"""
write = stream.write
indent += self._indent_per_level
delimnl = ",\n" + " " * indent
delim = ""
width = max_width = self._width - indent + 1
it = iter(object)
try:
next_ent = next(it)
except StopIteration:
return
last = False
n_items = 0
while not last:
if n_items == self.n_max_elements_to_show:
write(", ...")
break
n_items += 1
ent = next_ent
try:
next_ent = next(it)
except StopIteration:
last = True
max_width -= allowance
width -= allowance
if self._compact:
k, v = ent
krepr = self._repr(k, context, level)
vrepr = self._repr(v, context, level)
if not is_dict:
krepr = krepr.strip("'")
middle = ": " if is_dict else "="
rep = krepr + middle + vrepr
w = len(rep) + 2
if width < w:
width = max_width
if delim:
delim = delimnl
if width >= w:
width -= w
write(delim)
delim = ", "
write(rep)
continue
write(delim)
delim = delimnl
class_ = KeyValTuple if is_dict else KeyValTupleParam
self._format(
class_(ent), stream, indent, allowance if last else 1, context, level
)
def _format_items(self, items, stream, indent, allowance, context, level):
    """Format the items of an iterable (list, tuple...). Same as the
    built-in _format_items, with support for ellipsis if the number of
    elements is greater than self.n_max_elements_to_show.
    """
    write = stream.write
    indent += self._indent_per_level
    if self._indent_per_level > 1:
        write((self._indent_per_level - 1) * " ")
    # Delimiter used when an item has to wrap onto a new (indented) line.
    delimnl = ",\n" + " " * indent
    delim = ""
    width = max_width = self._width - indent + 1
    it = iter(items)
    try:
        next_ent = next(it)
    except StopIteration:
        # Empty iterable: nothing to write.
        return
    last = False
    n_items = 0
    while not last:
        if n_items == self.n_max_elements_to_show:
            # Truncate long sequences with an ellipsis.
            write(", ...")
            break
        n_items += 1
        ent = next_ent
        try:
            next_ent = next(it)
        except StopIteration:
            last = True
            max_width -= allowance
            width -= allowance
        if self._compact:
            rep = self._repr(ent, context, level)
            w = len(rep) + 2  # +2 accounts for the ", " separator
            if width < w:
                # No room left on the current line: wrap to a fresh line.
                width = max_width
                if delim:
                    delim = delimnl
            if width >= w:
                # Item fits inline on the current line.
                width -= w
                write(delim)
                delim = ", "
                write(rep)
                continue
        # Non-compact mode (or item too wide): recurse, one item per line.
        write(delim)
        delim = delimnl
        self._format(ent, stream, indent, allowance if last else 1, context, level)
def _pprint_key_val_tuple(self, object, stream, indent, allowance, context, level):
    """Pretty-print one (key, value) pair from a dict or a parameter list.

    Dict entries are written as ``'key': value``; estimator parameters
    (``KeyValTupleParam`` instances) are written as ``key=value`` with the
    quotes stripped from the key's repr.
    """
    key, value = object
    is_param = isinstance(object, KeyValTupleParam)
    key_repr = self._repr(key, context, level)
    if is_param:
        key_repr = key_repr.strip("'")
    separator = "=" if is_param else ": "
    stream.write(key_repr)
    stream.write(separator)
    # Format the value recursively, shifting the indent past what was
    # already written so wrapped values line up under the value column.
    self._format(
        value,
        stream,
        indent + len(key_repr) + len(separator),
        allowance,
        context,
        level,
    )
# Note: need to copy _dispatch to prevent instances of the builtin
# PrettyPrinter class to call methods of _EstimatorPrettyPrinter (see issue
# 12906)
# mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch"
_dispatch = pprint.PrettyPrinter._dispatch.copy() # type: ignore[attr-defined]
# Route estimators and key/value tuples to the custom printers defined above.
_dispatch[BaseEstimator.__repr__] = _pprint_estimator
_dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple
def _safe_repr(object, context, maxlevels, level, changed_only=False):
    """Same as the builtin _safe_repr, with added support for Estimator
    objects.

    Returns a ``(repr_string, is_readable, is_recursive)`` triple mirroring
    the contract of ``pprint._safe_repr``. When ``changed_only`` is True,
    estimators only show parameters that differ from their defaults.
    """
    typ = type(object)
    if typ in pprint._builtin_scalars:
        return repr(object), True, False
    r = getattr(typ, "__repr__", None)
    # dicts -- but only those that use the stock dict.__repr__ (a subclass
    # overriding __repr__ falls through to the generic repr() at the bottom).
    if issubclass(typ, dict) and r is dict.__repr__:
        if not object:
            return "{}", True, False
        objid = id(object)
        if maxlevels and level >= maxlevels:
            return "{...}", False, objid in context
        if objid in context:
            # This object is already being rendered higher up the call
            # stack: self-referential structure detected.
            return pprint._recursion(object), False, True
        context[objid] = 1
        readable = True
        recursive = False
        components = []
        append = components.append
        level += 1
        saferepr = _safe_repr
        items = sorted(object.items(), key=pprint._safe_tuple)
        for k, v in items:
            krepr, kreadable, krecur = saferepr(
                k, context, maxlevels, level, changed_only=changed_only
            )
            vrepr, vreadable, vrecur = saferepr(
                v, context, maxlevels, level, changed_only=changed_only
            )
            append("%s: %s" % (krepr, vrepr))
            readable = readable and kreadable and vreadable
            if krecur or vrecur:
                recursive = True
        del context[objid]
        return "{%s}" % ", ".join(components), readable, recursive
    # lists and tuples -- again only with the stock __repr__
    if (issubclass(typ, list) and r is list.__repr__) or (
        issubclass(typ, tuple) and r is tuple.__repr__
    ):
        if issubclass(typ, list):
            if not object:
                return "[]", True, False
            format = "[%s]"
        elif len(object) == 1:
            # One-element tuples need the trailing comma to round-trip.
            format = "(%s,)"
        else:
            if not object:
                return "()", True, False
            format = "(%s)"
        objid = id(object)
        if maxlevels and level >= maxlevels:
            return format % "...", False, objid in context
        if objid in context:
            return pprint._recursion(object), False, True
        context[objid] = 1
        readable = True
        recursive = False
        components = []
        append = components.append
        level += 1
        for o in object:
            orepr, oreadable, orecur = _safe_repr(
                o, context, maxlevels, level, changed_only=changed_only
            )
            append(orepr)
            if not oreadable:
                readable = False
            if orecur:
                recursive = True
        del context[objid]
        return format % ", ".join(components), readable, recursive
    # scikit-learn estimators: rendered as EstimatorName(param=value, ...)
    if issubclass(typ, BaseEstimator):
        objid = id(object)
        if maxlevels and level >= maxlevels:
            return f"{typ.__name__}(...)", False, objid in context
        if objid in context:
            return pprint._recursion(object), False, True
        context[objid] = 1
        readable = True
        recursive = False
        if changed_only:
            params = _changed_params(object)
        else:
            params = object.get_params(deep=False)
        components = []
        append = components.append
        level += 1
        saferepr = _safe_repr
        items = sorted(params.items(), key=pprint._safe_tuple)
        for k, v in items:
            krepr, kreadable, krecur = saferepr(
                k, context, maxlevels, level, changed_only=changed_only
            )
            vrepr, vreadable, vrecur = saferepr(
                v, context, maxlevels, level, changed_only=changed_only
            )
            # Parameter keys are written unquoted (key=value form).
            append("%s=%s" % (krepr.strip("'"), vrepr))
            readable = readable and kreadable and vreadable
            if krecur or vrecur:
                recursive = True
        del context[objid]
        return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive)
    # Fallback: rely on the object's own repr; "<...>" reprs are not
    # considered readable (they cannot be eval'd back).
    rep = repr(object)
    return rep, (rep and not rep.startswith("<")), False

View File

@@ -0,0 +1,34 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from ._typedefs cimport uint32_t
# Fallback seed used by `our_rand_r` below when the caller's state reaches 0
# (an XorShift generator whose state is 0 would emit 0 forever).
cdef inline uint32_t DEFAULT_SEED = 1

cdef enum:
    # Max value for our rand_r replacement (near the bottom).
    # We don't use RAND_MAX because it's different across platforms and
    # particularly tiny on Windows/MSVC.
    # It corresponds to the maximum representable value for
    # 32-bit signed integers (i.e. 2^31 - 1).
    RAND_R_MAX = 2147483647
# rand_r replacement using a 32bit XorShift generator
# See http://www.jstatsoft.org/v08/i14/paper for details
cdef inline uint32_t our_rand_r(uint32_t* seed) nogil:
    """Generate a pseudo-random np.uint32 from a np.uint32 seed"""
    # seed shouldn't ever be 0.
    if (seed[0] == 0):
        seed[0] = DEFAULT_SEED
    # xorshift32 update: three shift-and-xor steps mutate the state in place.
    seed[0] ^= <uint32_t>(seed[0] << 13)
    seed[0] ^= <uint32_t>(seed[0] >> 17)
    seed[0] ^= <uint32_t>(seed[0] << 5)
    # Use the modulo to make sure that we don't return a values greater than the
    # maximum representable value for signed 32bit integers (i.e. 2^31 - 1).
    # Note that the parenthesis are needed to avoid overflow: here
    # RAND_R_MAX is cast to uint32_t before 1 is added.
    return seed[0] % ((<uint32_t>RAND_R_MAX) + 1)

View File

@@ -0,0 +1,355 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
"""
Random utility function
=======================
This module complements missing features of ``numpy.random``.
The module contains:
* Several algorithms to sample integers without replacement.
* Fast rand_r alternative based on xor shifts
"""
import numpy as np
from . import check_random_state
from ._typedefs cimport intp_t
# Fallback seed for the XorShift helpers. NOTE(review): `uint32_t` is not
# cimported in this file's visible imports -- presumably provided by the
# matching .pxd; confirm.
cdef uint32_t DEFAULT_SEED = 1

# Compatibility type to always accept the default int type used by NumPy, both
# before and after NumPy 2. On Windows, `long` does not always match `intp_t`.
# See the comments in the `sample_without_replacement` Python function for more
# details.
ctypedef fused default_int:
    intp_t
    long
cpdef _sample_without_replacement_check_input(default_int n_population,
                                              default_int n_samples):
    """ Check that input are consistent for sample_without_replacement"""
    if n_population < 0:
        # NOTE(review): the message says "greater than 0" but the check only
        # rejects negative values (0 is accepted) -- confirm intended wording.
        raise ValueError('n_population should be greater than 0, got %s.'
                         % n_population)

    if n_samples > n_population:
        raise ValueError('n_population should be greater or equal than '
                         'n_samples, got n_samples > n_population (%s > %s)'
                         % (n_samples, n_population))
cpdef _sample_without_replacement_with_tracking_selection(
        default_int n_population,
        default_int n_samples,
        random_state=None):
    r"""Sample integers without replacement.

    Select n_samples integers from the set [0, n_population) without
    replacement.

    Time complexity:
        - Worst-case: unbounded
        - Average-case:
            O(O(np.random.randint) * \sum_{i=1}^n_samples 1 /
              (1 - i / n_population)))
            <= O(O(np.random.randint) *
                 n_population * ln((n_population - 2)
                 /(n_population - 1 - n_samples)))
            <= O(O(np.random.randint) *
                 n_population * 1 / (1 - n_samples / n_population))

    Space complexity of O(n_samples) in a python set.

    Parameters
    ----------
    n_population : int
        The size of the set to sample from.

    n_samples : int
        The number of integer to sample.

    random_state : int, RandomState instance or None, default=None
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Returns
    -------
    out : ndarray of shape (n_samples,)
        The sampled subsets of integer.
    """
    _sample_without_replacement_check_input(n_population, n_samples)

    cdef default_int i
    cdef default_int j
    cdef default_int[::1] out = np.empty((n_samples, ), dtype=int)

    rng = check_random_state(random_state)
    rng_randint = rng.randint

    # The following line of code are heavily inspired from python core,
    # more precisely of random.sample.
    cdef set selected = set()

    for i in range(n_samples):
        j = rng_randint(n_population)
        # Rejection sampling: redraw until we hit an integer not yet taken.
        while j in selected:
            j = rng_randint(n_population)
        selected.add(j)
        out[i] = j

    return np.asarray(out)
cpdef _sample_without_replacement_with_pool(default_int n_population,
                                            default_int n_samples,
                                            random_state=None):
    """Sample integers without replacement.

    Select n_samples integers from the set [0, n_population) without
    replacement.

    Time complexity: O(n_population +  O(np.random.randint) * n_samples)

    Space complexity of O(n_population + n_samples).

    Parameters
    ----------
    n_population : int
        The size of the set to sample from.

    n_samples : int
        The number of integer to sample.

    random_state : int, RandomState instance or None, default=None
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Returns
    -------
    out : ndarray of shape (n_samples,)
        The sampled subsets of integer.
    """
    _sample_without_replacement_check_input(n_population, n_samples)

    cdef default_int i
    cdef default_int j
    cdef default_int[::1] out = np.empty((n_samples,), dtype=int)
    cdef default_int[::1] pool = np.empty((n_population,), dtype=int)

    rng = check_random_state(random_state)
    rng_randint = rng.randint

    # Initialize the pool
    for i in range(n_population):
        pool[i] = i

    # The following line of code are heavily inspired from python core,
    # more precisely of random.sample.
    for i in range(n_samples):
        j = rng_randint(n_population - i)  # invariant: non-selected at [0,n-i)
        out[i] = pool[j]
        pool[j] = pool[n_population - i - 1]  # move non-selected item into vacancy

    return np.asarray(out)
cpdef _sample_without_replacement_with_reservoir_sampling(
    default_int n_population,
    default_int n_samples,
    random_state=None
):
    """Sample integers without replacement.

    Select n_samples integers from the set [0, n_population) without
    replacement.

    Time complexity of
        O((n_population - n_samples) * O(np.random.randint) + n_samples)
    Space complexity of O(n_samples)

    Parameters
    ----------
    n_population : int
        The size of the set to sample from.

    n_samples : int
        The number of integer to sample.

    random_state : int, RandomState instance or None, default=None
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Returns
    -------
    out : ndarray of shape (n_samples,)
        The sampled subsets of integer. The order of the items is not
        necessarily random. Use a random permutation of the array if the order
        of the items has to be randomized.
    """
    _sample_without_replacement_check_input(n_population, n_samples)

    cdef default_int i
    cdef default_int j
    cdef default_int[::1] out = np.empty((n_samples, ), dtype=int)

    rng = check_random_state(random_state)
    rng_randint = rng.randint

    # This cython implementation is based on the one of Robert Kern:
    # http://mail.scipy.org/pipermail/numpy-discussion/2010-December/
    # 054289.html
    #
    # Seed the reservoir with the first n_samples integers...
    for i in range(n_samples):
        out[i] = i

    # ...then give each remaining integer i a n_samples/(i+1) chance of
    # replacing a uniformly chosen reservoir slot.
    # Modernized from the legacy (deprecated) Cython `for i from a <= i < b`
    # loop syntax to the equivalent `range` form used elsewhere in this file.
    for i in range(n_samples, n_population):
        j = rng_randint(0, i + 1)
        if j < n_samples:
            out[j] = i

    return np.asarray(out)
cdef _sample_without_replacement(default_int n_population,
                                 default_int n_samples,
                                 method="auto",
                                 random_state=None):
    """Sample integers without replacement.

    Private function for the implementation, see sample_without_replacement
    documentation for more details.
    """
    _sample_without_replacement_check_input(n_population, n_samples)

    all_methods = ("auto", "tracking_selection", "reservoir_sampling", "pool")

    # Fraction of the population being requested; an empty population is
    # treated as ratio 1.0 to avoid dividing by zero.
    ratio = <double> n_samples / n_population if n_population != 0.0 else 1.0

    # Check ratio and use permutation unless ratio < 0.01 or ratio > 0.99
    if method == "auto" and ratio > 0.01 and ratio < 0.99:
        rng = check_random_state(random_state)
        return rng.permutation(n_population)[:n_samples]

    if method == "auto" or method == "tracking_selection":
        # TODO the pool based method can also be used.
        #      however, it requires special benchmark to take into account
        #      the memory requirement of the array vs the set.
        # The value 0.2 has been determined through benchmarking.
        if ratio < 0.2:
            return _sample_without_replacement_with_tracking_selection(
                n_population, n_samples, random_state)
        else:
            return _sample_without_replacement_with_reservoir_sampling(
                n_population, n_samples, random_state)

    elif method == "reservoir_sampling":
        return _sample_without_replacement_with_reservoir_sampling(
            n_population, n_samples, random_state)

    elif method == "pool":
        return _sample_without_replacement_with_pool(n_population, n_samples,
                                                     random_state)
    else:
        raise ValueError('Expected a method name in %s, got %s. '
                         % (all_methods, method))
def sample_without_replacement(
        object n_population, object n_samples, method="auto", random_state=None):
    """Sample integers without replacement.

    Select n_samples integers from the set [0, n_population) without
    replacement.

    Parameters
    ----------
    n_population : int
        The size of the set to sample from.

    n_samples : int
        The number of integer to sample.

    random_state : int, RandomState instance or None, default=None
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    method : {"auto", "tracking_selection", "reservoir_sampling", "pool"}, \
            default='auto'
        If method == "auto", the ratio of n_samples / n_population is used
        to determine which algorithm to use:
        If ratio is between 0 and 0.01, tracking selection is used.
        If ratio is between 0.01 and 0.99, numpy.random.permutation is used.
        If ratio is greater than 0.99, reservoir sampling is used.
        The order of the selected integers is undefined. If a random order is
        desired, the selected subset should be shuffled.

        If method =="tracking_selection", a set based implementation is used
        which is suitable for `n_samples` <<< `n_population`.

        If method == "reservoir_sampling", a reservoir sampling algorithm is
        used which is suitable for high memory constraint or when
        O(`n_samples`) ~ O(`n_population`).
        The order of the selected integers is undefined. If a random order is
        desired, the selected subset should be shuffled.

        If method == "pool", a pool based algorithm is particularly fast, even
        faster than the tracking selection method. However, a vector containing
        the entire population has to be initialized.
        If n_samples ~ n_population, the reservoir sampling method is faster.

    Returns
    -------
    out : ndarray of shape (n_samples,)
        The sampled subsets of integer. The subset of selected integer might
        not be randomized, see the method argument.

    Examples
    --------
    >>> from sklearn.utils.random import sample_without_replacement
    >>> sample_without_replacement(10, 5, random_state=42)
    array([8, 1, 5, 0, 7])
    """
    cdef:
        intp_t n_pop_intp, n_samples_intp
        long n_pop_long, n_samples_long

    # On most platforms `np.int_ is np.intp`. However, before NumPy 2 the
    # default integer `np.int_` was a long which is 32bit on 64bit windows
    # while `intp` is 64bit on 64bit platforms and 32bit on 32bit ones.
    if np.int_ is np.intp:
        # Branch always taken on NumPy >=2 (or when not on 64bit windows).
        # Cython has different rules for conversion of values to integers.
        # For NumPy <1.26.2 AND Cython 3, this first branch requires `int()`
        # called explicitly to allow e.g. floats.
        n_pop_intp = int(n_population)
        n_samples_intp = int(n_samples)
        return _sample_without_replacement(
            n_pop_intp, n_samples_intp, method, random_state)
    else:
        # Branch taken on 64bit windows with Numpy<2.0 where `long` is 32bit
        n_pop_long = n_population
        n_samples_long = n_samples
        return _sample_without_replacement(
            n_pop_long, n_samples_long, method, random_state)
def _our_rand_r_py(seed):
    """Python utils to test the our_rand_r function"""
    # Copy the Python int into a C uint32_t local so we can pass its address
    # to the pointer-based C helper.
    cdef uint32_t my_seed = seed
    return our_rand_r(&my_seed)

View File

@@ -0,0 +1,2 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

View File

@@ -0,0 +1,152 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import itertools
from ... import __version__
from ..._config import get_config
from ..fixes import parse_version
class _HTMLDocumentationLinkMixin:
    """Mixin class allowing to generate a link to the API documentation.

    This mixin relies on three attributes:
    - `_doc_link_module`: it corresponds to the root module (e.g. `sklearn`). Using this
      mixin, the default value is `sklearn`.
    - `_doc_link_template`: it corresponds to the template used to generate the
      link to the API documentation. Using this mixin, the default value is
      `"https://scikit-learn.org/{version_url}/modules/generated/
      {estimator_module}.{estimator_name}.html"`.
    - `_doc_link_url_param_generator`: it corresponds to a function that generates the
      parameters to be used in the template when the estimator module and name are not
      sufficient.

    The method :meth:`_get_doc_link` generates the link to the API documentation for a
    given estimator.

    This mixin provides all the necessary state for
    :func:`sklearn.utils.estimator_html_repr` to generate a link to the API
    documentation for the estimator HTML diagram.

    Examples
    --------
    If the default values for `_doc_link_module`, `_doc_link_template` are not suitable,
    then you can override them and provide a method to generate the URL parameters:

    >>> from sklearn.base import BaseEstimator
    >>> doc_link_template = "https://address.local/{single_param}.html"
    >>> def url_param_generator(estimator):
    ...     return {"single_param": estimator.__class__.__name__}
    >>> class MyEstimator(BaseEstimator):
    ...     # use "builtins" since it is the associated module when declaring
    ...     # the class in a docstring
    ...     _doc_link_module = "builtins"
    ...     _doc_link_template = doc_link_template
    ...     _doc_link_url_param_generator = url_param_generator
    >>> estimator = MyEstimator()
    >>> estimator._get_doc_link()
    'https://address.local/MyEstimator.html'

    If instead of overriding the attributes inside the class definition, you want to
    override a class instance, you can use `types.MethodType` to bind the method to the
    instance:

    >>> import types
    >>> estimator = BaseEstimator()
    >>> estimator._doc_link_template = doc_link_template
    >>> estimator._doc_link_url_param_generator = types.MethodType(
    ...     url_param_generator, estimator)
    >>> estimator._get_doc_link()
    'https://address.local/BaseEstimator.html'
    """

    # Root module an estimator must belong to for a doc link to be generated.
    _doc_link_module = "sklearn"
    # Optional callable returning the dict of template parameters; when None,
    # the module/name pair is derived automatically in `_get_doc_link`.
    _doc_link_url_param_generator = None

    @property
    def _doc_link_template(self):
        # Released versions link to the versioned documentation; development
        # builds link to the "dev" documentation.
        sklearn_version = parse_version(__version__)
        if sklearn_version.dev is None:
            version_url = f"{sklearn_version.major}.{sklearn_version.minor}"
        else:
            version_url = "dev"
        # An instance-level override (if any) is stored under the literal
        # attribute name "__doc_link_template" by the setter below; string
        # names passed to getattr/setattr are not subject to name mangling.
        return getattr(
            self,
            "__doc_link_template",
            (
                f"https://scikit-learn.org/{version_url}/modules/generated/"
                "{estimator_module}.{estimator_name}.html"
            ),
        )

    @_doc_link_template.setter
    def _doc_link_template(self, value):
        setattr(self, "__doc_link_template", value)

    def _get_doc_link(self):
        """Generates a link to the API documentation for a given estimator.

        This method generates the link to the estimator's documentation page
        by using the template defined by the attribute `_doc_link_template`.

        Returns
        -------
        url : str
            The URL to the API documentation for this estimator. If the estimator does
            not belong to module `_doc_link_module`, the empty string (i.e. `""`) is
            returned.
        """
        if self.__class__.__module__.split(".")[0] != self._doc_link_module:
            return ""

        if self._doc_link_url_param_generator is None:
            estimator_name = self.__class__.__name__
            # Construct the estimator's module name, up to the first private submodule.
            # This works because in scikit-learn all public estimators are exposed at
            # that level, even if they actually live in a private sub-module.
            estimator_module = ".".join(
                itertools.takewhile(
                    lambda part: not part.startswith("_"),
                    self.__class__.__module__.split("."),
                )
            )
            return self._doc_link_template.format(
                estimator_module=estimator_module, estimator_name=estimator_name
            )

        return self._doc_link_template.format(**self._doc_link_url_param_generator())
class ReprHTMLMixin:
    """Mixin to handle consistently the HTML representation.

    When inheriting from this class, you need to define an attribute `_html_repr`
    which is a callable that returns the HTML representation to be shown.
    """

    @property
    def _repr_html_(self):
        """HTML representation of estimator.

        This is redundant with the logic of `_repr_mimebundle_`. The latter
        should be favored in the long term, `_repr_html_` is only
        implemented for consumers who do not interpret `_repr_mimebundle_`.
        """
        if get_config()["display"] != "diagram":
            raise AttributeError(
                "_repr_html_ is only defined when the "
                "'display' configuration option is set to "
                "'diagram'"
            )
        # Return the *bound method* (not its result): consumers evaluate the
        # `_repr_html_` attribute and then call it, which invokes
        # `_repr_html_inner()`.
        return self._repr_html_inner

    def _repr_html_inner(self):
        """This function is returned by the @property `_repr_html_` to make
        `hasattr(estimator, "_repr_html_")` return `True` or `False` depending
        on `get_config()["display"]`.
        """
        return self._html_repr()

    def _repr_mimebundle_(self, **kwargs):
        """Mime bundle used by jupyter kernels to display estimator"""
        output = {"text/plain": repr(self)}
        if get_config()["display"] == "diagram":
            output["text/html"] = self._html_repr()
        return output

View File

@@ -0,0 +1,413 @@
#$id {
  /* Definition of color scheme common for light and dark mode */
  --sklearn-color-text: #000;
  --sklearn-color-text-muted: #666;
  --sklearn-color-line: gray;
  /* Definition of color scheme for unfitted estimators */
  --sklearn-color-unfitted-level-0: #fff5e6;
  --sklearn-color-unfitted-level-1: #f6e4d2;
  --sklearn-color-unfitted-level-2: #ffe0b3;
  --sklearn-color-unfitted-level-3: chocolate;
  /* Definition of color scheme for fitted estimators */
  --sklearn-color-fitted-level-0: #f0f8ff;
  --sklearn-color-fitted-level-1: #d4ebff;
  --sklearn-color-fitted-level-2: #b3dbfd;
  --sklearn-color-fitted-level-3: cornflowerblue;
  /* Specific color for light theme */
  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));
  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));
  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));
  --sklearn-color-icon: #696969;

  /* Nested media query (CSS nesting): overrides the theme-specific custom
     properties above when the user agent prefers a dark color scheme. */
  @media (prefers-color-scheme: dark) {
    /* Redefinition of color scheme for dark theme */
    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));
    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));
    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));
    --sklearn-color-icon: #878787;
  }
}
#$id {
  color: var(--sklearn-color-text);
}

#$id pre {
  padding: 0;
}

/* Visually-hidden checkbox that drives the expand/collapse toggles. */
#$id input.sk-hidden--visually {
  border: 0;
  clip: rect(1px 1px 1px 1px);
  clip: rect(1px, 1px, 1px, 1px);
  height: 1px;
  margin: -1px;
  overflow: hidden;
  padding: 0;
  position: absolute;
  width: 1px;
}

#$id div.sk-dashed-wrapped {
  border: 1px dashed var(--sklearn-color-line);
  margin: 0 0.4em 0.5em 0.4em;
  box-sizing: border-box;
  padding-bottom: 0.4em;
  background-color: var(--sklearn-color-background);
}

#$id div.sk-container {
  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`
     but bootstrap.min.css set `[hidden] { display: none !important; }`
     so we also need the `!important` here to be able to override the
     default hidden behavior on the sphinx rendered scikit-learn.org.
     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */
  display: inline-block !important;
  position: relative;
}

#$id div.sk-text-repr-fallback {
  display: none;
}

/* NOTE(review): unlike every other rule in this sheet, the selectors below are
   not scoped under `#$id`, so they apply to any matching element on the page
   -- confirm this global scope is intentional. */
div.sk-parallel-item,
div.sk-serial,
div.sk-item {
  /* draw centered vertical line to link estimators */
  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));
  background-size: 2px 100%;
  background-repeat: no-repeat;
  background-position: center center;
}

/* Parallel-specific style estimator block */

#$id div.sk-parallel-item::after {
  content: "";
  width: 100%;
  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);
  flex-grow: 1;
}

#$id div.sk-parallel {
  display: flex;
  align-items: stretch;
  justify-content: center;
  background-color: var(--sklearn-color-background);
  position: relative;
}

#$id div.sk-parallel-item {
  display: flex;
  flex-direction: column;
}

/* Connector rules: first/last items draw half-width rules so the branches of
   the diagram join up; an only child needs no horizontal rule at all. */
#$id div.sk-parallel-item:first-child::after {
  align-self: flex-end;
  width: 50%;
}

#$id div.sk-parallel-item:last-child::after {
  align-self: flex-start;
  width: 50%;
}

#$id div.sk-parallel-item:only-child::after {
  width: 0;
}

/* Serial-specific style estimator block */

#$id div.sk-serial {
  display: flex;
  flex-direction: column;
  align-items: center;
  background-color: var(--sklearn-color-background);
  padding-right: 1em;
  padding-left: 1em;
}
/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is
   clickable and can be expanded/collapsed.
   - Pipeline and ColumnTransformer use this feature and define the default style
   - Estimators will overwrite some part of the style using the `sk-estimator` class
*/

/* Pipeline and ColumnTransformer style (default) */

#$id div.sk-toggleable {
  /* Default theme specific background. It is overwritten whether we have a
     specific estimator or a Pipeline/ColumnTransformer */
  background-color: var(--sklearn-color-background);
}

/* Toggleable label */
#$id label.sk-toggleable__label {
  cursor: pointer;
  display: flex;
  width: 100%;
  margin-bottom: 0;
  padding: 0.5em;
  box-sizing: border-box;
  text-align: center;
  align-items: start;
  justify-content: space-between;
  gap: 0.5em;
}

#$id label.sk-toggleable__label .caption {
  font-size: 0.6rem;
  font-weight: lighter;
  color: var(--sklearn-color-text-muted);
}

#$id label.sk-toggleable__label-arrow:before {
  /* Arrow on the left of the label */
  content: "▸";
  float: left;
  margin-right: 0.25em;
  color: var(--sklearn-color-icon);
}

#$id label.sk-toggleable__label-arrow:hover:before {
  color: var(--sklearn-color-text);
}

/* Toggleable content - dropdown */

#$id div.sk-toggleable__content {
  display: none;
  text-align: left;
  /* unfitted */
  background-color: var(--sklearn-color-unfitted-level-0);
}

#$id div.sk-toggleable__content.fitted {
  /* fitted */
  background-color: var(--sklearn-color-fitted-level-0);
}

#$id div.sk-toggleable__content pre {
  margin: 0.2em;
  border-radius: 0.25em;
  color: var(--sklearn-color-text);
  /* unfitted */
  background-color: var(--sklearn-color-unfitted-level-0);
}

#$id div.sk-toggleable__content.fitted pre {
  /* fitted */
  background-color: var(--sklearn-color-fitted-level-0);
}

/* A checked (hidden) checkbox expands the sibling dropdown and flips the
   arrow glyph from ▸ to ▾. */
#$id input.sk-toggleable__control:checked~div.sk-toggleable__content {
  /* Expand drop-down */
  display: block;
  width: 100%;
  overflow: visible;
}

#$id input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {
  content: "▾";
}
/* Pipeline/ColumnTransformer-specific style */

#$id div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {
  color: var(--sklearn-color-text);
  background-color: var(--sklearn-color-unfitted-level-2);
}

#$id div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {
  background-color: var(--sklearn-color-fitted-level-2);
}

/* Estimator-specific style */

/* Colorize estimator box */
#$id div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {
  /* unfitted */
  background-color: var(--sklearn-color-unfitted-level-2);
}

#$id div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {
  /* fitted */
  background-color: var(--sklearn-color-fitted-level-2);
}

#$id div.sk-label label.sk-toggleable__label,
#$id div.sk-label label {
  /* The background is the default theme color */
  color: var(--sklearn-color-text-on-default-background);
}

/* On hover, darken the color of the background */
#$id div.sk-label:hover label.sk-toggleable__label {
  color: var(--sklearn-color-text);
  background-color: var(--sklearn-color-unfitted-level-2);
}

/* Label box, darken color on hover, fitted */
#$id div.sk-label.fitted:hover label.sk-toggleable__label.fitted {
  color: var(--sklearn-color-text);
  background-color: var(--sklearn-color-fitted-level-2);
}

/* Estimator label */

#$id div.sk-label label {
  font-family: monospace;
  font-weight: bold;
  display: inline-block;
  line-height: 1.2em;
}

#$id div.sk-label-container {
  text-align: center;
}

/* Estimator-specific */
#$id div.sk-estimator {
  font-family: monospace;
  border: 1px dotted var(--sklearn-color-border-box);
  border-radius: 0.25em;
  box-sizing: border-box;
  margin-bottom: 0.5em;
  /* unfitted */
  background-color: var(--sklearn-color-unfitted-level-0);
}

#$id div.sk-estimator.fitted {
  /* fitted */
  background-color: var(--sklearn-color-fitted-level-0);
}

/* on hover */
#$id div.sk-estimator:hover {
  /* unfitted */
  background-color: var(--sklearn-color-unfitted-level-2);
}

#$id div.sk-estimator.fitted:hover {
  /* fitted */
  background-color: var(--sklearn-color-fitted-level-2);
}
/* Specification for estimator info (e.g. "i" and "?") */

/* Common style for "i" and "?" */
.sk-estimator-doc-link,
a:link.sk-estimator-doc-link,
a:visited.sk-estimator-doc-link {
  float: right;
  font-size: smaller;
  line-height: 1em;
  font-family: monospace;
  background-color: var(--sklearn-color-background);
  border-radius: 1em;
  height: 1em;
  width: 1em;
  text-decoration: none !important;
  margin-left: 0.5em;
  text-align: center;
  /* unfitted */
  border: var(--sklearn-color-unfitted-level-1) 1pt solid;
  color: var(--sklearn-color-unfitted-level-1);
}

.sk-estimator-doc-link.fitted,
a:link.sk-estimator-doc-link.fitted,
a:visited.sk-estimator-doc-link.fitted {
  /* fitted */
  border: var(--sklearn-color-fitted-level-1) 1pt solid;
  color: var(--sklearn-color-fitted-level-1);
}

/* On hover */
/* NOTE(review): `.sk-estimator-doc-link:hover` appears twice in each selector
   list below (standalone and within a container) -- the duplicate standalone
   entry is redundant; confirm before simplifying. */
div.sk-estimator:hover .sk-estimator-doc-link:hover,
.sk-estimator-doc-link:hover,
div.sk-label-container:hover .sk-estimator-doc-link:hover,
.sk-estimator-doc-link:hover {
  /* unfitted */
  background-color: var(--sklearn-color-unfitted-level-3);
  color: var(--sklearn-color-background);
  text-decoration: none;
}

div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,
.sk-estimator-doc-link.fitted:hover,
div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,
.sk-estimator-doc-link.fitted:hover {
  /* fitted */
  background-color: var(--sklearn-color-fitted-level-3);
  color: var(--sklearn-color-background);
  text-decoration: none;
}
/* Span, style for the box shown on hovering the info icon */
.sk-estimator-doc-link span {
  display: none;
  z-index: 9999;
  position: relative;
  font-weight: normal;
  right: .2ex;
  padding: .5ex;
  margin: .5ex;
  width: min-content;
  min-width: 20ex;
  max-width: 50ex;
  color: var(--sklearn-color-text);
  box-shadow: 2pt 2pt 4pt #999;
  /* unfitted */
  background: var(--sklearn-color-unfitted-level-0);
  border: .5pt solid var(--sklearn-color-unfitted-level-3);
}

.sk-estimator-doc-link.fitted span {
  /* fitted */
  background: var(--sklearn-color-fitted-level-0);
  /* Fix: the `border` shorthand previously specified only a color, which
     leaves border-style at its initial value `none` so no border was drawn.
     Match the width/style of the unfitted variant above. */
  border: .5pt solid var(--sklearn-color-fitted-level-3);
}

.sk-estimator-doc-link:hover span {
  display: block;
}
/* "?"-specific style due to the `<a>` HTML tag */
#$id a.estimator_doc_link {
float: right;
font-size: 1rem;
line-height: 1em;
font-family: monospace;
background-color: var(--sklearn-color-background);
border-radius: 1rem;
height: 1rem;
width: 1rem;
text-decoration: none;
/* unfitted */
color: var(--sklearn-color-unfitted-level-1);
border: var(--sklearn-color-unfitted-level-1) 1pt solid;
}
#$id a.estimator_doc_link.fitted {
/* fitted */
border: var(--sklearn-color-fitted-level-1) 1pt solid;
color: var(--sklearn-color-fitted-level-1);
}
/* On hover */
#$id a.estimator_doc_link:hover {
/* unfitted */
background-color: var(--sklearn-color-unfitted-level-3);
color: var(--sklearn-color-background);
text-decoration: none;
}
#$id a.estimator_doc_link.fitted:hover {
/* fitted */
background-color: var(--sklearn-color-fitted-level-3);
}

View File

@@ -0,0 +1,42 @@
/**
 * Copy a parameter name, prefixed with its nested-estimator path, to the
 * clipboard and give transient visual feedback on the clicked icon.
 *
 * @param {string} text - The bare parameter name shown next to the icon.
 * @param {Element} element - The clicked copy icon element.
 * @returns {boolean} Always false so an inline onclick handler cancels the
 *     default action.
 */
function copyToClipboard(text, element) {
    // Get the parameter prefix from the closest toggleable content
    const toggleableContent = element.closest('.sk-toggleable__content');
    const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : '';
    const fullParamName = paramPrefix ? `${paramPrefix}${text}` : text;

    // NOTE(review): `element.style` is a live CSSStyleDeclaration, so
    // `originalStyle` holds a reference rather than a snapshot; the later
    // `element.style = originalStyle` restore presumably relies on the
    // attribute's string coercion — confirm the color/width reset works.
    const originalStyle = element.style;
    const computedStyle = window.getComputedStyle(element);
    const originalWidth = computedStyle.width;
    // Strip any feedback text left over from a previous rapid click.
    const originalHTML = element.innerHTML.replace('Copied!', '');

    navigator.clipboard.writeText(fullParamName)
        .then(() => {
            // Pin the width so the icon does not jump while showing text.
            element.style.width = originalWidth;
            element.style.color = 'green';
            element.innerHTML = "Copied!";
            // Revert the feedback after two seconds.
            setTimeout(() => {
                element.innerHTML = originalHTML;
                element.style = originalStyle;
            }, 2000);
        })
        .catch(err => {
            console.error('Failed to copy:', err);
            element.style.color = 'red';
            element.innerHTML = "Failed!";
            setTimeout(() => {
                element.innerHTML = originalHTML;
                element.style = originalStyle;
            }, 2000);
        });
    return false;
}
// On load, give every copy icon a tooltip (title attribute) showing the
// fully-prefixed parameter name it will copy, e.g. "stepname__param".
document.querySelectorAll('.fa-regular.fa-copy').forEach(function(element) {
    const toggleableContent = element.closest('.sk-toggleable__content');
    const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : '';
    // The parameter name lives in the <td> following the icon's parent cell.
    const paramName = element.parentElement.nextElementSibling.textContent.trim();
    const fullParamName = paramPrefix ? `${paramPrefix}${paramName}` : paramName;

    element.setAttribute('title', fullParamName);
});

View File

@@ -0,0 +1,497 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import html
from contextlib import closing
from inspect import isclass
from io import StringIO
from pathlib import Path
from string import Template
from ... import config_context
class _IDCounter:
"""Generate sequential ids with a prefix."""
def __init__(self, prefix):
self.prefix = prefix
self.count = 0
def get_id(self):
self.count += 1
return f"{self.prefix}-{self.count}"
def _get_css_style():
    """Read and join the two CSS files shipped next to this module.

    Returns the contents of ``estimator.css`` followed by ``params.css``,
    separated by a newline.
    """
    here = Path(__file__).parent
    sheets = [
        (here / fname).read_text(encoding="utf-8")
        for fname in ("estimator.css", "params.css")
    ]
    return "\n".join(sheets)
# Module-level counters so every rendered container/estimator gets a unique
# DOM id within the interpreter session.
_CONTAINER_ID_COUNTER = _IDCounter("sk-container-id")
_ESTIMATOR_ID_COUNTER = _IDCounter("sk-estimator-id")

# The CSS is read once at import time and injected into every HTML repr.
_CSS_STYLE = _get_css_style()
class _VisualBlock:
"""HTML Representation of Estimator
Parameters
----------
kind : {'serial', 'parallel', 'single'}
kind of HTML block
estimators : list of estimators or `_VisualBlock`s or a single estimator
If kind != 'single', then `estimators` is a list of
estimators.
If kind == 'single', then `estimators` is a single estimator.
names : list of str, default=None
If kind != 'single', then `names` corresponds to estimators.
If kind == 'single', then `names` is a single string corresponding to
the single estimator.
name_details : list of str, str, or None, default=None
If kind != 'single', then `name_details` corresponds to `names`.
If kind == 'single', then `name_details` is a single string
corresponding to the single estimator.
name_caption : str, default=None
The caption below the name. `None` stands for no caption.
Only active when kind == 'single'.
doc_link_label : str, default=None
The label for the documentation link. If provided, the label would be
"Documentation for {doc_link_label}". Otherwise it will look for `names`.
Only active when kind == 'single'.
dash_wrapped : bool, default=True
If true, wrapped HTML element will be wrapped with a dashed border.
Only active when kind != 'single'.
"""
def __init__(
self,
kind,
estimators,
*,
names=None,
name_details=None,
name_caption=None,
doc_link_label=None,
dash_wrapped=True,
):
self.kind = kind
self.estimators = estimators
self.dash_wrapped = dash_wrapped
self.name_caption = name_caption
self.doc_link_label = doc_link_label
if self.kind in ("parallel", "serial"):
if names is None:
names = (None,) * len(estimators)
if name_details is None:
name_details = (None,) * len(estimators)
self.names = names
self.name_details = name_details
def _sk_visual_block_(self):
return self
def _write_label_html(
    out,
    params,
    name,
    name_details,
    name_caption=None,
    doc_link_label=None,
    outer_class="sk-label-container",
    inner_class="sk-label",
    checked=False,
    doc_link="",
    is_fitted_css_class="",
    is_fitted_icon="",
    param_prefix="",
):
    """Write labeled html with or without a dropdown with named details.

    Parameters
    ----------
    out : file-like object
        The file to write the HTML representation to.

    params : str
        If estimator has `get_params` method, this is the HTML representation
        of the estimator's parameters and their values. When the estimator
        does not have `get_params`, it is an empty string.

    name : str
        The label for the estimator. It corresponds either to the estimator class name
        for a simple estimator or in the case of a `Pipeline` and `ColumnTransformer`,
        it corresponds to the name of the step.

    name_details : str
        The details to show as content in the dropdown part of the toggleable label. It
        can contain information such as non-default parameters or column information for
        `ColumnTransformer`.

    name_caption : str, default=None
        The caption below the name. If `None`, no caption will be created.

    doc_link_label : str, default=None
        The label for the documentation link. If provided, the label would be
        "Documentation for {doc_link_label}". Otherwise it will look for `name`.

    outer_class : {"sk-label-container", "sk-item"}, default="sk-label-container"
        The CSS class for the outer container.

    inner_class : {"sk-label", "sk-estimator"}, default="sk-label"
        The CSS class for the inner container.

    checked : bool, default=False
        Whether the dropdown is folded or not. With a single estimator, we intend to
        unfold the content.

    doc_link : str, default=""
        The link to the documentation for the estimator. If an empty string, no link is
        added to the diagram. This can be generated for an estimator if it uses the
        `_HTMLDocumentationLinkMixin`.

    is_fitted_css_class : {"", "fitted"}
        The CSS class to indicate whether or not the estimator is fitted. The
        empty string means that the estimator is not fitted and "fitted" means that the
        estimator is fitted.

    is_fitted_icon : str, default=""
        The HTML representation to show the fitted information in the diagram. An empty
        string means that no information is shown.

    param_prefix : str, default=""
        The prefix to prepend to parameter names for nested estimators.
    """
    # Open the two nested containers; they are closed at the end of the
    # function regardless of which branch is taken.
    out.write(
        f'<div class="{outer_class}"><div'
        f' class="{inner_class} {is_fitted_css_class} sk-toggleable">'
    )
    name = html.escape(name)

    if name_details is not None:
        # Toggleable (checkbox-driven) label with a dropdown body.
        name_details = html.escape(str(name_details))
        checked_str = "checked" if checked else ""
        est_id = _ESTIMATOR_ID_COUNTER.get_id()

        if doc_link:
            # Build the "?" anchor; the tooltip text prefers the explicit
            # doc_link_label, then the estimator name, then a generic label.
            doc_label = "<span>Online documentation</span>"
            if doc_link_label is not None:
                doc_label = f"<span>Documentation for {doc_link_label}</span>"
            elif name is not None:
                doc_label = f"<span>Documentation for {name}</span>"
            doc_link = (
                f'<a class="sk-estimator-doc-link {is_fitted_css_class}"'
                f' rel="noreferrer" target="_blank" href="{doc_link}">?{doc_label}</a>'
            )

        name_caption_div = (
            ""
            if name_caption is None
            else f'<div class="caption">{html.escape(name_caption)}</div>'
        )
        name_caption_div = f"<div><div>{name}</div>{name_caption_div}</div>"
        links_div = (
            f"<div>{doc_link}{is_fitted_icon}</div>"
            if doc_link or is_fitted_icon
            else ""
        )

        label_html = (
            f'<label for="{est_id}" class="sk-toggleable__label {is_fitted_css_class} '
            f'sk-toggleable__label-arrow">{name_caption_div}{links_div}</label>'
        )

        # The hidden checkbox drives the CSS-only expand/collapse behavior;
        # data-param-prefix is read by the clipboard JS.
        fmt_str = (
            f'<input class="sk-toggleable__control sk-hidden--visually" id="{est_id}" '
            f'type="checkbox" {checked_str}>{label_html}<div '
            f'class="sk-toggleable__content {is_fitted_css_class}" '
            f'data-param-prefix="{html.escape(param_prefix)}">'
        )
        # Dropdown body: prefer the parameter table, otherwise the raw
        # details (skipped for Pipelines, whose steps are drawn instead).
        if params:
            fmt_str = "".join([fmt_str, f"{params}</div>"])
        elif name_details and ("Pipeline" not in name):
            fmt_str = "".join([fmt_str, f"<pre>{name_details}</pre></div>"])
        out.write(fmt_str)
    else:
        # No details: a plain, non-toggleable label.
        out.write(f"<label>{name}</label>")
    out.write("</div></div>")  # outer_class inner_class
def _get_visual_block(estimator):
    """Generate information about how to display an estimator."""
    # An estimator may describe its own rendering via `_sk_visual_block_`;
    # fall back to a plain single block if that hook raises.
    if hasattr(estimator, "_sk_visual_block_"):
        try:
            return estimator._sk_visual_block_()
        except Exception:
            return _VisualBlock(
                "single",
                estimator,
                names=estimator.__class__.__name__,
                name_details=str(estimator),
            )

    # Strings such as "passthrough"/"drop", and None, are rendered verbatim.
    if isinstance(estimator, str):
        return _VisualBlock(
            "single", estimator, names=estimator, name_details=estimator
        )
    if estimator is None:
        return _VisualBlock("single", estimator, names="None", name_details="None")

    # Duck-typed meta-estimators: anything exposing fitted sub-estimators
    # through get_params is drawn as a parallel group of those children.
    if hasattr(estimator, "get_params") and not isclass(estimator):
        children = [
            (key, sub)
            for key, sub in estimator.get_params(deep=False).items()
            if hasattr(sub, "get_params") and hasattr(sub, "fit") and not isclass(sub)
        ]
        if children:
            return _VisualBlock(
                "parallel",
                [sub for _, sub in children],
                names=[f"{key}: {sub.__class__.__name__}" for key, sub in children],
                name_details=[str(sub) for _, sub in children],
            )

    # Default: a leaf node labeled with the class name.
    return _VisualBlock(
        "single",
        estimator,
        names=estimator.__class__.__name__,
        name_details=str(estimator),
    )
def _write_estimator_html(
    out,
    estimator,
    estimator_label,
    estimator_label_details,
    is_fitted_css_class,
    is_fitted_icon="",
    first_call=False,
    param_prefix="",
):
    """Write estimator to html in serial, parallel, or by itself (single).

    For multiple estimators, this function is called recursively.

    Parameters
    ----------
    out : file-like object
        The file to write the HTML representation to.

    estimator : estimator object
        The estimator to visualize.

    estimator_label : str
        The label for the estimator. It corresponds either to the estimator class name
        for simple estimator or in the case of `Pipeline` and `ColumnTransformer`, it
        corresponds to the name of the step.

    estimator_label_details : str
        The details to show as content in the dropdown part of the toggleable label.
        It can contain information as non-default parameters or column information for
        `ColumnTransformer`.

    is_fitted_css_class : {"", "fitted"}
        The CSS class to indicate whether or not the estimator is fitted or not. The
        empty string means that the estimator is not fitted and "fitted" means that the
        estimator is fitted.

    is_fitted_icon : str, default=""
        The HTML representation to show the fitted information in the diagram. An empty
        string means that no information is shown. If the estimator to be shown is not
        the first estimator (i.e. `first_call=False`), `is_fitted_icon` is always an
        empty string.

    first_call : bool, default=False
        Whether this is the first time this function is called.

    param_prefix : str, default=""
        The prefix to prepend to parameter names for nested estimators.
        For example, in a pipeline this might be "pipeline__stepname__".
    """
    if first_call:
        est_block = _get_visual_block(estimator)
    else:
        # Only the outermost estimator shows the fitted icon; nested ones
        # are rendered with changed-only params for brevity.
        is_fitted_icon = ""
        with config_context(print_changed_only=True):
            est_block = _get_visual_block(estimator)
    # `estimator` can also be an instance of `_VisualBlock`
    if hasattr(estimator, "_get_doc_link"):
        doc_link = estimator._get_doc_link()
    else:
        doc_link = ""
    if est_block.kind in ("serial", "parallel"):
        dashed_wrapped = first_call or est_block.dash_wrapped
        dash_cls = " sk-dashed-wrapped" if dashed_wrapped else ""
        out.write(f'<div class="sk-item{dash_cls}">')

        if estimator_label:
            if hasattr(estimator, "get_params") and hasattr(
                estimator, "_get_params_html"
            ):
                params = estimator._get_params_html(deep=False)._repr_html_inner()
            else:
                params = ""

            _write_label_html(
                out,
                params,
                estimator_label,
                estimator_label_details,
                doc_link=doc_link,
                is_fitted_css_class=is_fitted_css_class,
                is_fitted_icon=is_fitted_icon,
                param_prefix=param_prefix,
            )

        kind = est_block.kind
        out.write(f'<div class="sk-{kind}">')
        est_infos = zip(est_block.estimators, est_block.names, est_block.name_details)

        for est, name, name_details in est_infos:
            # Build the parameter prefix for nested estimators
            if param_prefix and hasattr(name, "split"):
                # If we already have a prefix, append the new component
                new_prefix = f"{param_prefix}{name.split(':')[0]}__"
            elif hasattr(name, "split"):
                # If this is the first level, start the prefix
                new_prefix = f"{name.split(':')[0]}__" if name else ""
            else:
                # `name` is None (unlabeled child): keep the current prefix.
                new_prefix = param_prefix

            if kind == "serial":
                _write_estimator_html(
                    out,
                    est,
                    name,
                    name_details,
                    is_fitted_css_class=is_fitted_css_class,
                    param_prefix=new_prefix,
                )
            else:  # parallel
                out.write('<div class="sk-parallel-item">')
                # wrap element in a serial visualblock
                serial_block = _VisualBlock("serial", [est], dash_wrapped=False)
                _write_estimator_html(
                    out,
                    serial_block,
                    name,
                    name_details,
                    is_fitted_css_class=is_fitted_css_class,
                    param_prefix=new_prefix,
                )
                out.write("</div>")  # sk-parallel-item

        out.write("</div></div>")
    elif est_block.kind == "single":
        if hasattr(estimator, "_get_params_html"):
            params = estimator._get_params_html()._repr_html_inner()
        else:
            params = ""

        _write_label_html(
            out,
            params,
            est_block.names,
            est_block.name_details,
            est_block.name_caption,
            est_block.doc_link_label,
            outer_class="sk-item",
            inner_class="sk-estimator",
            checked=first_call,
            doc_link=doc_link,
            is_fitted_css_class=is_fitted_css_class,
            is_fitted_icon=is_fitted_icon,
            param_prefix=param_prefix,
        )
def estimator_html_repr(estimator):
    """Build a HTML representation of an estimator.

    Read more in the :ref:`User Guide <visualizing_composite_estimators>`.

    Parameters
    ----------
    estimator : estimator object
        The estimator to visualize.

    Returns
    -------
    html: str
        HTML representation of estimator.

    Examples
    --------
    >>> from sklearn.utils._repr_html.estimator import estimator_html_repr
    >>> from sklearn.linear_model import LogisticRegression
    >>> estimator_html_repr(LogisticRegression())
    '<style>#sk-container-id...'
    """
    # Imported locally to avoid an import cycle at module import time.
    from sklearn.exceptions import NotFittedError
    from sklearn.utils.validation import check_is_fitted

    # Determine the fitted status shown in the top-right icon.
    if not hasattr(estimator, "fit"):
        status_label = "<span>Not fitted</span>"
        is_fitted_css_class = ""
    else:
        try:
            check_is_fitted(estimator)
            status_label = "<span>Fitted</span>"
            is_fitted_css_class = "fitted"
        except NotFittedError:
            status_label = "<span>Not fitted</span>"
            is_fitted_css_class = ""

    is_fitted_icon = (
        f'<span class="sk-estimator-doc-link {is_fitted_css_class}">'
        f"i{status_label}</span>"
    )
    with closing(StringIO()) as out:
        container_id = _CONTAINER_ID_COUNTER.get_id()
        # Scope the CSS to this container by substituting its id into the
        # `$id` placeholders of the style sheet.
        style_template = Template(_CSS_STYLE)
        style_with_id = style_template.substitute(id=container_id)
        estimator_str = str(estimator)

        # The fallback message is shown by default and loading the CSS sets
        # div.sk-text-repr-fallback to display: none to hide the fallback message.
        #
        # If the notebook is trusted, the CSS is loaded which hides the fallback
        # message. If the notebook is not trusted, then the CSS is not loaded and the
        # fallback message is shown by default.
        #
        # The reverse logic applies to HTML repr div.sk-container.
        # div.sk-container is hidden by default and the loading the CSS displays it.
        fallback_msg = (
            "In a Jupyter environment, please rerun this cell to show the HTML"
            " representation or trust the notebook. <br />On GitHub, the"
            " HTML representation is unable to render, please try loading this page"
            " with nbviewer.org."
        )
        html_template = (
            f"<style>{style_with_id}</style>"
            f"<body>"
            f'<div id="{container_id}" class="sk-top-container">'
            '<div class="sk-text-repr-fallback">'
            f"<pre>{html.escape(estimator_str)}</pre><b>{fallback_msg}</b>"
            "</div>"
            '<div class="sk-container" hidden>'
        )

        out.write(html_template)

        _write_estimator_html(
            out,
            estimator,
            estimator.__class__.__name__,
            estimator_str,
            first_call=True,
            is_fitted_css_class=is_fitted_css_class,
            is_fitted_icon=is_fitted_icon,
        )

        # Inline the clipboard JS shipped next to this module.
        with open(str(Path(__file__).parent / "estimator.js"), "r") as f:
            script = f.read()

        html_end = f"</div></div><script>{script}</script></body>"

        out.write(html_end)
        html_output = out.getvalue()
        return html_output

View File

@@ -0,0 +1,63 @@
/* Styles for the collapsible parameter table rendered inside each
   estimator box of the HTML repr. */
.estimator-table summary {
    padding: .5rem;
    font-family: monospace;
    cursor: pointer;
}

.estimator-table details[open] {
    padding-left: 0.1rem;
    padding-right: 0.1rem;
    padding-bottom: 0.3rem;
}

.estimator-table .parameters-table {
    margin-left: auto !important;
    margin-right: auto !important;
}

/* Zebra striping plus a hover highlight for parameter rows. */
.estimator-table .parameters-table tr:nth-child(odd) {
    background-color: #fff;
}

.estimator-table .parameters-table tr:nth-child(even) {
    background-color: #f6f6f6;
}

.estimator-table .parameters-table tr:hover {
    background-color: #e0e0e0;
}

.estimator-table table td {
    border: 1px solid rgba(106, 105, 104, 0.232);
}

/* Rows whose parameter was explicitly set by the user are orange... */
.user-set td {
    color:rgb(255, 94, 0);
    text-align: left;
}

.user-set td.value pre {
    color:rgb(255, 94, 0) !important;
    background-color: transparent !important;
}

/* ...while rows still at their default value stay black. */
.default td {
    color: black;
    text-align: left;
}

/* The copy icon itself is always black, whatever the row color. */
.user-set td i,
.default td i {
    color: black;
}

/* Copy icon: Font Awesome "copy" glyph inlined as a base64 SVG data URL. */
.copy-paste-icon {
    background-image: url(data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCA0NDggNTEyIj48IS0tIUZvbnQgQXdlc29tZSBGcmVlIDYuNy4yIGJ5IEBmb250YXdlc29tZSAtIGh0dHBzOi8vZm9udGF3ZXNvbWUuY29tIExpY2Vuc2UgLSBodHRwczovL2ZvbnRhd2Vzb21lLmNvbS9saWNlbnNlL2ZyZWUgQ29weXJpZ2h0IDIwMjUgRm9udGljb25zLCBJbmMuLS0+PHBhdGggZD0iTTIwOCAwTDMzMi4xIDBjMTIuNyAwIDI0LjkgNS4xIDMzLjkgMTQuMWw2Ny45IDY3LjljOSA5IDE0LjEgMjEuMiAxNC4xIDMzLjlMNDQ4IDMzNmMwIDI2LjUtMjEuNSA0OC00OCA0OGwtMTkyIDBjLTI2LjUgMC00OC0yMS41LTQ4LTQ4bDAtMjg4YzAtMjYuNSAyMS41LTQ4IDQ4LTQ4ek00OCAxMjhsODAgMCAwIDY0LTY0IDAgMCAyNTYgMTkyIDAgMC0zMiA2NCAwIDAgNDhjMCAyNi41LTIxLjUgNDgtNDggNDhMNDggNTEyYy0yNi41IDAtNDgtMjEuNS00OC00OEwwIDE3NmMwLTI2LjUgMjEuNS00OCA0OC00OHoiLz48L3N2Zz4=);
    background-repeat: no-repeat;
    background-size: 14px 14px;
    background-position: 0;
    display: inline-block;
    width: 14px;
    height: 14px;
    cursor: pointer;
}

View File

@@ -0,0 +1,83 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import html
import reprlib
from collections import UserDict
from sklearn.utils._repr_html.base import ReprHTMLMixin
def _read_params(name, value, non_default_params):
"""Categorizes parameters as 'default' or 'user-set' and formats their values.
Escapes or truncates parameter values for display safety and readability.
"""
r = reprlib.Repr()
r.maxlist = 2 # Show only first 2 items of lists
r.maxtuple = 1 # Show only first item of tuples
r.maxstring = 50 # Limit string length
cleaned_value = html.escape(r.repr(value))
param_type = "user-set" if name in non_default_params else "default"
return {"param_type": param_type, "param_name": name, "param_value": cleaned_value}
def _params_html_repr(params):
    """Generate HTML representation of estimator parameters.

    Creates an HTML table with parameter names and values, wrapped in a
    collapsible details element. Parameters are styled differently based
    on whether they are default or user-set values.
    """
    # HTML whitespace is insignificant here; the templates are formatted for
    # readability only.
    HTML_TEMPLATE = """
    <div class="estimator-table">
        <details>
            <summary>Parameters</summary>
            <table class="parameters-table">
              <tbody>
                {rows}
              </tbody>
            </table>
        </details>
    </div>
    """
    ROW_TEMPLATE = """
    <tr class="{param_type}">
        <td><i class="copy-paste-icon"
                 onclick="copyToClipboard('{param_name}',
                          this.parentElement.nextElementSibling)"
        ></i></td>
        <td class="param">{param_name}&nbsp;</td>
        <td class="value">{param_value}</td>
    </tr>
    """

    # One row per parameter; _read_params decides the CSS class ("user-set"
    # vs "default") and the escaped/truncated value text.
    rows = [
        ROW_TEMPLATE.format(**_read_params(name, value, params.non_default))
        for name, value in params.items()
    ]
    return HTML_TEMPLATE.format(rows="\n".join(rows))
class ParamsDict(ReprHTMLMixin, UserDict):
    """Dictionary-like class to store and provide an HTML representation.

    It builds an HTML structure to be used with Jupyter notebooks or similar
    environments. It allows storing metadata to track non-default parameters.

    Parameters
    ----------
    params : dict, default=None
        The original dictionary of parameters and their values.

    non_default : tuple
        The list of non-default parameters.
    """

    # Hook consumed by ReprHTMLMixin to build the rich `_repr_html_` output.
    _html_repr = _params_html_repr

    def __init__(self, params=None, non_default=tuple()):
        super().__init__(params or {})
        # Names of parameters whose values differ from their defaults.
        self.non_default = non_default

View File

@@ -0,0 +1,616 @@
import html
import locale
import re
import types
from contextlib import closing
from functools import partial
from io import StringIO
from unittest.mock import patch
import numpy as np
import pytest
from sklearn import config_context
from sklearn.base import BaseEstimator
from sklearn.cluster import AgglomerativeClustering, Birch
from sklearn.compose import ColumnTransformer, make_column_transformer
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.ensemble import StackingClassifier, StackingRegressor, VotingClassifier
from sklearn.feature_selection import SelectPercentile
from sklearn.gaussian_process.kernels import ExpSineSquared
from sklearn.impute import SimpleImputer
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
from sklearn.multiclass import OneVsOneClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import FeatureUnion, Pipeline, make_pipeline
from sklearn.preprocessing import FunctionTransformer, OneHotEncoder, StandardScaler
from sklearn.svm import LinearSVC, LinearSVR
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils._repr_html.base import _HTMLDocumentationLinkMixin
from sklearn.utils._repr_html.estimator import (
_get_css_style,
_get_visual_block,
_write_label_html,
estimator_html_repr,
)
from sklearn.utils.fixes import parse_version
def dummy_function(x, y):
    """Trivial binary add used as a stand-in callable in these tests."""
    result = x + y
    return result  # pragma: nocover
@pytest.mark.parametrize("checked", [True, False])
def test_write_label_html(checked):
    # Test checking logic and labeling
    name = "LogisticRegression"
    params = ""
    tool_tip = "hello-world"

    with closing(StringIO()) as out:
        _write_label_html(out, params, name, tool_tip, checked=checked)
        html_label = out.getvalue()

        # The toggleable label should carry the estimator name.
        p = (
            r'<label for="sk-estimator-id-[0-9]*"'
            r' class="sk-toggleable__label (fitted)? sk-toggleable__label-arrow">'
            r"<div><div>LogisticRegression</div></div>"
        )
        re_compiled = re.compile(p)
        assert re_compiled.search(html_label)

        assert html_label.startswith('<div class="sk-label-container">')
        assert "<pre>hello-world</pre>" in html_label
        if checked:
            # The hidden checkbox is pre-checked so the dropdown is unfolded.
            assert "checked>" in html_label
@pytest.mark.parametrize("est", ["passthrough", "drop", None])
def test_get_visual_block_single_str_none(est):
    # Test estimators that are represented by strings
    est_html_info = _get_visual_block(est)
    assert est_html_info.kind == "single"
    assert est_html_info.estimators == est
    # Both the label and the details fall back to the string form.
    assert est_html_info.names == str(est)
    assert est_html_info.name_details == str(est)
def test_get_visual_block_single_estimator():
    # A plain estimator renders as a single block labeled by its class name.
    est = LogisticRegression(C=10.0)
    est_html_info = _get_visual_block(est)
    assert est_html_info.kind == "single"
    assert est_html_info.estimators == est
    assert est_html_info.names == est.__class__.__name__
    assert est_html_info.name_details == str(est)
def test_get_visual_block_pipeline():
    # A Pipeline renders as a serial block with one entry per step.
    pipe = Pipeline(
        [
            ("imputer", SimpleImputer()),
            ("do_nothing", "passthrough"),
            ("do_nothing_more", None),
            ("classifier", LogisticRegression()),
        ]
    )
    est_html_info = _get_visual_block(pipe)
    assert est_html_info.kind == "serial"
    assert est_html_info.estimators == tuple(step[1] for step in pipe.steps)
    # `None` steps are displayed as "passthrough", like explicit ones.
    assert est_html_info.names == [
        "imputer: SimpleImputer",
        "do_nothing: passthrough",
        "do_nothing_more: passthrough",
        "classifier: LogisticRegression",
    ]
    assert est_html_info.name_details == [str(est) for _, est in pipe.steps]
def test_get_visual_block_feature_union():
    # A FeatureUnion renders its transformers side by side (parallel).
    f_union = FeatureUnion([("pca", PCA()), ("svd", TruncatedSVD())])
    est_html_info = _get_visual_block(f_union)
    assert est_html_info.kind == "parallel"
    assert est_html_info.names == ("pca", "svd")
    assert est_html_info.estimators == tuple(
        trans[1] for trans in f_union.transformer_list
    )
    assert est_html_info.name_details == (None, None)
def test_get_visual_block_voting():
    # A VotingClassifier renders its sub-estimators in parallel.
    clf = VotingClassifier(
        [("log_reg", LogisticRegression()), ("mlp", MLPClassifier())]
    )
    est_html_info = _get_visual_block(clf)
    assert est_html_info.kind == "parallel"
    assert est_html_info.estimators == tuple(trans[1] for trans in clf.estimators)
    assert est_html_info.names == ("log_reg", "mlp")
    assert est_html_info.name_details == (None, None)
def test_get_visual_block_column_transformer():
    # A ColumnTransformer renders in parallel with the selected columns
    # exposed as the per-transformer details.
    ct = ColumnTransformer(
        [("pca", PCA(), ["num1", "num2"]), ("svd", TruncatedSVD, [0, 3])]
    )
    est_html_info = _get_visual_block(ct)
    assert est_html_info.kind == "parallel"
    assert est_html_info.estimators == tuple(trans[1] for trans in ct.transformers)
    assert est_html_info.names == ("pca", "svd")
    assert est_html_info.name_details == (["num1", "num2"], [0, 3])
def test_estimator_html_repr_an_empty_pipeline():
    """Check that the representation of an empty Pipeline does not fail.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/30197
    """
    empty_pipeline = Pipeline([])
    # Only checks that rendering does not raise.
    estimator_html_repr(empty_pipeline)
def test_estimator_html_repr_pipeline():
    # End-to-end check on a deeply nested composite estimator: every level
    # (pipeline, column transformer, feature union, voting) must appear.
    num_trans = Pipeline(
        steps=[("pass", "passthrough"), ("imputer", SimpleImputer(strategy="median"))]
    )
    cat_trans = Pipeline(
        steps=[
            ("imputer", SimpleImputer(strategy="constant", missing_values="empty")),
            ("one-hot", OneHotEncoder(drop="first")),
        ]
    )
    preprocess = ColumnTransformer(
        [
            ("num", num_trans, ["a", "b", "c", "d", "e"]),
            ("cat", cat_trans, [0, 1, 2, 3]),
        ]
    )
    feat_u = FeatureUnion(
        [
            ("pca", PCA(n_components=1)),
            (
                "tsvd",
                Pipeline(
                    [
                        ("first", TruncatedSVD(n_components=3)),
                        ("select", SelectPercentile()),
                    ]
                ),
            ),
        ]
    )
    clf = VotingClassifier(
        [
            ("lr", LogisticRegression(solver="lbfgs", random_state=1)),
            ("mlp", MLPClassifier(alpha=0.001)),
        ]
    )
    pipe = Pipeline(
        [("preprocessor", preprocess), ("feat_u", feat_u), ("classifier", clf)]
    )
    html_output = estimator_html_repr(pipe)

    # top level estimators show estimator with changes
    assert html.escape(str(pipe)) in html_output
    for _, est in pipe.steps:
        assert html.escape(str(est))[:44] in html_output

    # low level estimators do not show changes
    with config_context(print_changed_only=True):
        assert html.escape(str(num_trans["pass"])) in html_output
        assert "<div><div>passthrough</div></div></label>" in html_output
        assert html.escape(str(num_trans["imputer"])) in html_output

        for _, _, cols in preprocess.transformers:
            assert f"<pre>{html.escape(str(cols))}</pre>" in html_output

        # feature union
        for name, _ in feat_u.transformer_list:
            assert f"<label>{html.escape(name)}</label>" in html_output
        pca = feat_u.transformer_list[0][1]
        assert html.escape(str(pca)) in html_output

        tsvd = feat_u.transformer_list[1][1]
        first = tsvd["first"]
        select = tsvd["select"]
        assert html.escape(str(first)) in html_output
        assert html.escape(str(select)) in html_output

        # voting classifier
        for name, est in clf.estimators:
            assert html.escape(name) in html_output
            assert html.escape(str(est)) in html_output

    # verify that prefers-color-scheme is implemented
    assert "prefers-color-scheme" in html_output
@pytest.mark.parametrize("final_estimator", [None, LinearSVC()])
def test_stacking_classifier(final_estimator):
    # The diagram must show the explicit final estimator, or the implicit
    # default one when final_estimator is None.
    estimators = [
        ("mlp", MLPClassifier(alpha=0.001)),
        ("tree", DecisionTreeClassifier()),
    ]
    clf = StackingClassifier(estimators=estimators, final_estimator=final_estimator)

    html_output = estimator_html_repr(clf)

    assert html.escape(str(clf)) in html_output
    # If final_estimator's default changes from LogisticRegression
    # this should be updated
    if final_estimator is None:
        assert "LogisticRegression" in html_output
    else:
        assert final_estimator.__class__.__name__ in html_output
@pytest.mark.parametrize("final_estimator", [None, LinearSVR()])
def test_stacking_regressor(final_estimator):
    # Both the base estimators and the (possibly default) final estimator
    # must appear as labeled blocks.
    reg = StackingRegressor(
        estimators=[("svr", LinearSVR())], final_estimator=final_estimator
    )
    html_output = estimator_html_repr(reg)

    assert html.escape(str(reg.estimators[0][0])) in html_output
    p = (
        r'<label for="sk-estimator-id-[0-9]*"'
        r' class="sk-toggleable__label (fitted)? sk-toggleable__label-arrow">'
        r"<div><div>LinearSVR</div></div>"
    )
    re_compiled = re.compile(p)
    assert re_compiled.search(html_output)

    if final_estimator is None:
        # Default final estimator is RidgeCV; update if the default changes.
        p = (
            r'<label for="sk-estimator-id-[0-9]*"'
            r' class="sk-toggleable__label (fitted)? sk-toggleable__label-arrow">'
            r"<div><div>RidgeCV</div></div>"
        )
        re_compiled = re.compile(p)
        assert re_compiled.search(html_output)
    else:
        assert html.escape(final_estimator.__class__.__name__) in html_output
def test_birch_duck_typing_meta():
    # Test duck typing meta estimators with Birch
    birch = Birch(n_clusters=AgglomerativeClustering(n_clusters=3))
    html_output = estimator_html_repr(birch)

    # inner estimators do not show changes
    with config_context(print_changed_only=True):
        assert f"<pre>{html.escape(str(birch.n_clusters))}" in html_output
        p = r"<div><div>AgglomerativeClustering</div></div><div>.+</div></label>"
        re_compiled = re.compile(p)
        assert re_compiled.search(html_output)

    # outer estimator contains all changes
    assert f"<pre>{html.escape(str(birch))}" in html_output
def test_ovo_classifier_duck_typing_meta():
    # Test duck typing metaestimators with OVO
    ovo = OneVsOneClassifier(LinearSVC(penalty="l1"))
    html_output = estimator_html_repr(ovo)

    # inner estimators do not show changes
    with config_context(print_changed_only=True):
        assert f"<pre>{html.escape(str(ovo.estimator))}" in html_output
        # regex to match the start of the tag
        p = (
            r'<label for="sk-estimator-id-[0-9]*" '
            r'class="sk-toggleable__label sk-toggleable__label-arrow">'
            r"<div><div>LinearSVC</div></div>"
        )
        re_compiled = re.compile(p)
        assert re_compiled.search(html_output)

    # outer estimator
    assert f"<pre>{html.escape(str(ovo))}" in html_output
def test_duck_typing_nested_estimator():
    # Test duck typing metaestimators with random search
    kernel_ridge = KernelRidge(kernel=ExpSineSquared())
    param_distributions = {"alpha": [1, 2]}

    kernel_ridge_tuned = RandomizedSearchCV(
        kernel_ridge,
        param_distributions=param_distributions,
    )
    html_output = estimator_html_repr(kernel_ridge_tuned)
    # The wrapped estimator appears labeled by its get_params key.
    assert "<div><div>estimator: KernelRidge</div></div></label>" in html_output
@pytest.mark.parametrize("print_changed_only", [True, False])
def test_one_estimator_print_change_only(print_changed_only):
    # The textual repr embedded in the HTML must honor the global
    # print_changed_only configuration.
    pca = PCA(n_components=10)

    with config_context(print_changed_only=print_changed_only):
        pca_repr = html.escape(str(pca))
        html_output = estimator_html_repr(pca)
        assert pca_repr in html_output
def test_fallback_exists():
    """Check that repr fallback is in the HTML."""
    pca = PCA(n_components=10)
    html_output = estimator_html_repr(pca)

    # The plain-text fallback is shown when the notebook is untrusted.
    assert (
        f'<div class="sk-text-repr-fallback"><pre>{html.escape(str(pca))}'
        in html_output
    )
def test_show_arrow_pipeline():
    """Show arrow in pipeline for top level in pipeline"""
    pipe = Pipeline([("scale", StandardScaler()), ("log_Reg", LogisticRegression())])

    html_output = estimator_html_repr(pipe)
    assert (
        'class="sk-toggleable__label sk-toggleable__label-arrow">'
        "<div><div>Pipeline</div></div>" in html_output
    )
def test_invalid_parameters_in_stacking():
    """Invalidate stacking configuration uses default repr.

    Non-regression test for #24009.
    """
    stacker = StackingClassifier(estimators=[])

    html_output = estimator_html_repr(stacker)
    assert html.escape(str(stacker)) in html_output
def test_estimator_get_params_return_cls():
    """Check HTML repr works where a value in get_params is a class."""

    class MyEstimator:
        def get_params(self, deep=False):
            # Returns a class (not an instance); rendering must not choke.
            return {"inner_cls": LogisticRegression}

    est = MyEstimator()
    assert "MyEstimator" in estimator_html_repr(est)
def test_estimator_html_repr_unfitted_vs_fitted():
    """Check that we have the information that the estimator is fitted or not in the
    HTML representation.
    """

    class MyEstimator(BaseEstimator):
        def fit(self, X, y):
            # Trailing-underscore attribute marks the estimator as fitted.
            self.fitted_ = True
            return self

    X, y = load_iris(return_X_y=True)
    estimator = MyEstimator()
    assert "<span>Not fitted</span>" in estimator_html_repr(estimator)
    estimator.fit(X, y)
    assert "<span>Fitted</span>" in estimator_html_repr(estimator)
@pytest.mark.parametrize(
    "estimator",
    [
        LogisticRegression(),
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), slice(0, 3))),
            LogisticRegression(),
        ),
    ],
)
def test_estimator_html_repr_fitted_icon(estimator):
    """Check that we are showing the fitted status icon only once."""
    # The icon must appear exactly once, on the outermost estimator only.
    pattern = '<span class="sk-estimator-doc-link ">i<span>Not fitted</span></span>'
    assert estimator_html_repr(estimator).count(pattern) == 1
    X, y = load_iris(return_X_y=True)
    estimator.fit(X, y)
    pattern = '<span class="sk-estimator-doc-link fitted">i<span>Fitted</span></span>'
    assert estimator_html_repr(estimator).count(pattern) == 1
@pytest.mark.parametrize("mock_version", ["1.3.0.dev0", "1.3.0"])
def test_html_documentation_link_mixin_sklearn(mock_version):
    """Check the `_HTMLDocumentationLinkMixin` defaults for scikit-learn itself."""
    # Mock `__version__` in the module where the mixin lives.
    with patch("sklearn.utils._repr_html.base.__version__", mock_version):
        mixin = _HTMLDocumentationLinkMixin()
        assert mixin._doc_link_module == "sklearn"

        # Parse the version manually so this test also passes on branches other
        # than `main` (whose version is "dev").
        parsed = parse_version(mock_version)
        version = "dev" if parsed.dev is not None else f"{parsed.major}.{parsed.minor}"

        expected_template = (
            f"https://scikit-learn.org/{version}/modules/generated/"
            "{estimator_module}.{estimator_name}.html"
        )
        assert mixin._doc_link_template == expected_template

        expected_link = (
            f"https://scikit-learn.org/{version}/modules/generated/"
            "sklearn.utils._HTMLDocumentationLinkMixin.html"
        )
        assert mixin._get_doc_link() == expected_link
@pytest.mark.parametrize(
    "module_path,expected_module",
    [
        ("prefix.mymodule", "prefix.mymodule"),
        ("prefix._mymodule", "prefix"),
        ("prefix.mypackage._mymodule", "prefix.mypackage"),
        ("prefix.mypackage._mymodule.submodule", "prefix.mypackage"),
        ("prefix.mypackage.mymodule.submodule", "prefix.mypackage.mymodule.submodule"),
    ],
)
def test_html_documentation_link_mixin_get_doc_link_instance(
    module_path, expected_module
):
    """`_get_doc_link` infers module and name when set on the instance."""

    class FooBar(_HTMLDocumentationLinkMixin):
        pass

    FooBar.__module__ = module_path
    est = FooBar()
    # Setting the `_doc_link_*` attributes on the instance is enough for the
    # mixin to infer the estimator's module and name.
    est._doc_link_module = "prefix"
    est._doc_link_template = (
        "https://website.com/{estimator_module}.{estimator_name}.html"
    )
    expected = f"https://website.com/{expected_module}.FooBar.html"
    assert est._get_doc_link() == expected
@pytest.mark.parametrize(
    "module_path,expected_module",
    [
        ("prefix.mymodule", "prefix.mymodule"),
        ("prefix._mymodule", "prefix"),
        ("prefix.mypackage._mymodule", "prefix.mypackage"),
        ("prefix.mypackage._mymodule.submodule", "prefix.mypackage"),
        ("prefix.mypackage.mymodule.submodule", "prefix.mypackage.mymodule.submodule"),
    ],
)
def test_html_documentation_link_mixin_get_doc_link_class(module_path, expected_module):
    """`_get_doc_link` also honors `_doc_link_module`/`_doc_link_template`
    defined at the class level rather than on the instance."""

    class FooBar(_HTMLDocumentationLinkMixin):
        _doc_link_module = "prefix"
        _doc_link_template = (
            "https://website.com/{estimator_module}.{estimator_name}.html"
        )

    FooBar.__module__ = module_path
    expected = f"https://website.com/{expected_module}.FooBar.html"
    assert FooBar()._get_doc_link() == expected
def test_html_documentation_link_mixin_get_doc_link_out_of_library():
    """`_get_doc_link` returns an empty string for out-of-library estimators."""
    mixin = _HTMLDocumentationLinkMixin()
    # `_doc_link_module` no longer matches the root module of the mixin, so no
    # documentation link can be built.
    mixin._doc_link_module = "xxx"
    assert mixin._get_doc_link() == ""
def test_html_documentation_link_mixin_doc_link_url_param_generator_instance():
    """A custom URL-parameter generator bound to the instance is honored."""
    mixin = _HTMLDocumentationLinkMixin()
    mixin._doc_link_template = (
        "https://website.com/{my_own_variable}.{another_variable}.html"
    )

    def url_param_generator(estimator):
        return {"my_own_variable": "value_1", "another_variable": "value_2"}

    mixin._doc_link_url_param_generator = types.MethodType(url_param_generator, mixin)
    assert mixin._get_doc_link() == "https://website.com/value_1.value_2.html"
def test_html_documentation_link_mixin_doc_link_url_param_generator_class():
    """A custom URL-parameter generator declared on the class is honored."""

    def url_param_generator(estimator):
        return {"my_own_variable": "value_1", "another_variable": "value_2"}

    class FooBar(_HTMLDocumentationLinkMixin):
        _doc_link_template = (
            "https://website.com/{my_own_variable}.{another_variable}.html"
        )
        _doc_link_url_param_generator = url_param_generator

    assert FooBar()._get_doc_link() == "https://website.com/value_1.value_2.html"
@pytest.fixture
def set_non_utf8_locale():
    """Run the test under a non utf-8 locale, restoring the locale afterwards."""
    try:
        locale.setlocale(locale.LC_CTYPE, "C")
    except locale.Error:
        pytest.skip("'C' locale is not available on this OS")
    yield
    # Restore the startup locale. Python calls setlocale(LC_TYPE, "") at
    # interpreter startup (see
    # https://docs.python.org/3/library/locale.html#background-details-hints-tips-and-caveats)
    # so the empty string brings back the original setting. This assumes no
    # other locale change has been made in the meantime; restoring via
    # locale.setlocale(locale.LC_CTYPE, locale.getlocale()) raises a
    # locale.Error (unsupported locale setting) on some platforms.
    locale.setlocale(locale.LC_CTYPE, "")
def test_non_utf8_locale(set_non_utf8_locale):
    """Reading the CSS file must force utf-8 decoding.

    Non-regression test for
    https://github.com/scikit-learn/scikit-learn/issues/27725
    """
    _get_css_style()
@pytest.mark.parametrize(
    "func, expected_name",
    [
        (lambda x: x + 1, html.escape("<lambda>")),
        (dummy_function, "dummy_function"),
        (partial(dummy_function, y=1), "dummy_function"),
        (np.vectorize(partial(dummy_function, y=1)), re.escape("vectorize(...)")),
    ],
)
def test_function_transformer_show_caption(func, expected_name):
    """The wrapped function's name is used as the label while
    "FunctionTransformer" appears in the caption underneath."""
    html_output = estimator_html_repr(FunctionTransformer(func))
    pattern = re.compile(
        r'<label for="sk-estimator-id-[0-9]*" class="sk-toggleable__label fitted '
        rf'sk-toggleable__label-arrow"><div><div>{expected_name}</div>'
        r'<div class="caption">FunctionTransformer</div></div>'
    )
    assert pattern.search(html_output)
def test_estimator_html_repr_table():
    """The HTML representation must embed the table of parameters."""
    estimator = LogisticRegression(C=10.0, fit_intercept=False)
    assert "parameters-table" in estimator_html_repr(estimator)

View File

@@ -0,0 +1,74 @@
import pytest
from sklearn import config_context
from sklearn.utils._repr_html.params import ParamsDict, _params_html_repr, _read_params
def test_params_dict_content():
    """ParamsDict behaves like a dict and tracks non-default parameters."""
    params = ParamsDict({"a": 1, "b": 2})
    assert (params["a"], params["b"]) == (1, 2)
    assert params.non_default == ()

    params = ParamsDict({"a": 1, "b": 2}, non_default=("a",))
    assert (params["a"], params["b"]) == (1, 2)
    assert params.non_default == ("a",)
def test_params_dict_repr_html_():
    """`_repr_html_` renders the parameters, but only in HTML display mode."""
    params = ParamsDict({"a": 1, "b": 2}, non_default=("a",))
    assert "<summary>Parameters</summary>" in params._repr_html_()

    with config_context(display="text"):
        with pytest.raises(AttributeError, match="_repr_html_ is only defined when"):
            params._repr_html_()
def test_params_dict_repr_mimebundle():
    """`_repr_mimebundle_` exposes text always, HTML only when enabled."""
    params = ParamsDict({"a": 1, "b": 2}, non_default=("a",))
    bundle = params._repr_mimebundle_()
    assert set(bundle) >= {"text/plain", "text/html"}
    assert "<summary>Parameters</summary>" in bundle["text/html"]
    assert bundle["text/plain"] == "{'a': 1, 'b': 2}"

    with config_context(display="text"):
        bundle = params._repr_mimebundle_()
        assert "text/plain" in bundle
        assert "text/html" not in bundle
def test_read_params():
    """`_read_params` categorizes parameters and escapes HTML in values."""
    default = _read_params("a", 1, tuple())
    assert default["param_type"] == "default"
    assert default["param_name"] == "a"
    assert default["param_value"] == "1"

    # A parameter listed in `non_default` is flagged as user-set.
    user_set = _read_params("a", 1, ("a",))
    assert user_set["param_type"] == "user-set"
    assert user_set["param_name"] == "a"
    assert user_set["param_value"] == "1"

    # HTML tags in the value must be escaped to avoid markup injection.
    tag_injection = "<script>alert('xss')</script>"
    escaped = _read_params("a", tag_injection, tuple())
    assert (
        escaped["param_value"]
        == "&quot;&lt;script&gt;alert(&#x27;xss&#x27;)&lt;/script&gt;&quot;"
    )
    assert escaped["param_name"] == "a"
    assert escaped["param_type"] == "default"
def test_params_html_repr():
    """`_params_html_repr` wraps the parameters in the expected tables."""
    rendered = _params_html_repr(ParamsDict({"a": 1, "b": 2}))
    assert "parameters-table" in rendered
    assert "estimator-table" in rendered

View File

@@ -0,0 +1,317 @@
"""Utilities to get the response values of a classifier or a regressor.
It allows to make uniform checks and validation.
"""
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
from ..base import is_classifier
from .multiclass import type_of_target
from .validation import _check_response_method, check_is_fitted
def _process_predict_proba(*, y_pred, target_type, classes, pos_label):
"""Get the response values when the response method is `predict_proba`.
This function process the `y_pred` array in the binary and multi-label cases.
In the binary case, it selects the column corresponding to the positive
class. In the multi-label case, it stacks the predictions if they are not
in the "compressed" format `(n_samples, n_outputs)`.
Parameters
----------
y_pred : ndarray
Output of `estimator.predict_proba`. The shape depends on the target type:
- for binary classification, it is a 2d array of shape `(n_samples, 2)`;
- for multiclass classification, it is a 2d array of shape
`(n_samples, n_classes)`;
- for multilabel classification, it is either a list of 2d arrays of shape
`(n_samples, 2)` (e.g. `RandomForestClassifier` or `KNeighborsClassifier`) or
an array of shape `(n_samples, n_outputs)` (e.g. `MLPClassifier` or
`RidgeClassifier`).
target_type : {"binary", "multiclass", "multilabel-indicator"}
Type of the target.
classes : ndarray of shape (n_classes,) or list of such arrays
Class labels as reported by `estimator.classes_`.
pos_label : int, float, bool or str
Only used with binary and multiclass targets.
Returns
-------
y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
(n_samples, n_output)
Compressed predictions format as requested by the metrics.
"""
if target_type == "binary" and y_pred.shape[1] < 2:
# We don't handle classifiers trained on a single class.
raise ValueError(
f"Got predict_proba of shape {y_pred.shape}, but need "
"classifier with two classes."
)
if target_type == "binary":
col_idx = np.flatnonzero(classes == pos_label)[0]
return y_pred[:, col_idx]
elif target_type == "multilabel-indicator":
# Use a compress format of shape `(n_samples, n_output)`.
# Only `MLPClassifier` and `RidgeClassifier` return an array of shape
# `(n_samples, n_outputs)`.
if isinstance(y_pred, list):
# list of arrays of shape `(n_samples, 2)`
return np.vstack([p[:, -1] for p in y_pred]).T
else:
# array of shape `(n_samples, n_outputs)`
return y_pred
return y_pred
def _process_decision_function(*, y_pred, target_type, classes, pos_label):
"""Get the response values when the response method is `decision_function`.
This function process the `y_pred` array in the binary and multi-label cases.
In the binary case, it inverts the sign of the score if the positive label
is not `classes[1]`. In the multi-label case, it stacks the predictions if
they are not in the "compressed" format `(n_samples, n_outputs)`.
Parameters
----------
y_pred : ndarray
Output of `estimator.decision_function`. The shape depends on the target type:
- for binary classification, it is a 1d array of shape `(n_samples,)` where the
sign is assuming that `classes[1]` is the positive class;
- for multiclass classification, it is a 2d array of shape
`(n_samples, n_classes)`;
- for multilabel classification, it is a 2d array of shape `(n_samples,
n_outputs)`.
target_type : {"binary", "multiclass", "multilabel-indicator"}
Type of the target.
classes : ndarray of shape (n_classes,) or list of such arrays
Class labels as reported by `estimator.classes_`.
pos_label : int, float, bool or str
Only used with binary and multiclass targets.
Returns
-------
y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
(n_samples, n_output)
Compressed predictions format as requested by the metrics.
"""
if target_type == "binary" and pos_label == classes[0]:
return -1 * y_pred
return y_pred
def _get_response_values(
    estimator,
    X,
    response_method,
    pos_label=None,
    return_response_method_used=False,
):
    """Compute the response values of a classifier, an outlier detector, or a regressor.

    The response values are predictions such that it follows the following shape:

    - for binary classification, it is a 1d array of shape `(n_samples,)`;
    - for multiclass classification, it is a 2d array of shape `(n_samples, n_classes)`;
    - for multilabel classification, it is a 2d array of shape `(n_samples, n_outputs)`;
    - for outlier detection, it is a 1d array of shape `(n_samples,)`;
    - for regression, it is a 1d array of shape `(n_samples,)`.

    If `estimator` is a binary classifier, also return the label for the
    effective positive class.

    This utility is used primarily in the displays and the scikit-learn scorers.

    .. versionadded:: 1.3

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier, outlier detector, or regressor or a
        fitted :class:`~sklearn.pipeline.Pipeline` in which the last estimator is a
        classifier, an outlier detector, or a regressor.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    response_method : {"predict_proba", "predict_log_proba", "decision_function", \
            "predict"} or list of such str
        Specifies the response method to use get prediction from an estimator
        (i.e. :term:`predict_proba`, :term:`predict_log_proba`,
        :term:`decision_function` or :term:`predict`). Possible choices are:

        - if `str`, it corresponds to the name to the method to return;
        - if a list of `str`, it provides the method names in order of
          preference. The method returned corresponds to the first method in
          the list and which is implemented by `estimator`.

    pos_label : int, float, bool or str, default=None
        The class considered as the positive class when computing
        the metrics. If `None` and target is 'binary', `estimators.classes_[1]` is
        considered as the positive class.

    return_response_method_used : bool, default=False
        Whether to return the response method used to compute the response
        values.

        .. versionadded:: 1.4

    Returns
    -------
    y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
            (n_samples, n_outputs)
        Target scores calculated from the provided `response_method`
        and `pos_label`.

    pos_label : int, float, bool, str or None
        The class considered as the positive class when computing
        the metrics. Returns `None` if `estimator` is a regressor or an outlier
        detector.

    response_method_used : str
        The response method used to compute the response values. Only returned
        if `return_response_method_used` is `True`.

        .. versionadded:: 1.4

    Raises
    ------
    ValueError
        If `pos_label` is not a valid label.
        If the shape of `y_pred` is not consistent for binary classifier.
        If the response method can be applied to a classifier only and
        `estimator` is a regressor.
    """
    # Function-level import; presumably avoids a circular import with
    # `sklearn.base` at module load time — TODO confirm.
    from sklearn.base import is_classifier, is_outlier_detector

    if is_classifier(estimator):
        # Resolve the (possibly list-valued) `response_method` to a bound method.
        prediction_method = _check_response_method(estimator, response_method)
        classes = estimator.classes_
        target_type = type_of_target(classes)

        if target_type in ("binary", "multiclass"):
            if pos_label is not None and pos_label not in classes.tolist():
                raise ValueError(
                    f"pos_label={pos_label} is not a valid label: It should be "
                    f"one of {classes}"
                )
            elif pos_label is None and target_type == "binary":
                # Default positive class: for a binary target, `classes[-1]`
                # is the same as `classes[1]` (see docstring above).
                pos_label = classes[-1]

        y_pred = prediction_method(X)

        # Post-process scores/probabilities into the compressed metric format.
        if prediction_method.__name__ in ("predict_proba", "predict_log_proba"):
            y_pred = _process_predict_proba(
                y_pred=y_pred,
                target_type=target_type,
                classes=classes,
                pos_label=pos_label,
            )
        elif prediction_method.__name__ == "decision_function":
            y_pred = _process_decision_function(
                y_pred=y_pred,
                target_type=target_type,
                classes=classes,
                pos_label=pos_label,
            )
    elif is_outlier_detector(estimator):
        prediction_method = _check_response_method(estimator, response_method)
        # Outlier detectors have no notion of a positive class.
        y_pred, pos_label = prediction_method(X), None
    else:  # estimator is a regressor
        if response_method != "predict":
            raise ValueError(
                f"{estimator.__class__.__name__} should either be a classifier to be "
                f"used with response_method={response_method} or the response_method "
                "should be 'predict'. Got a regressor with response_method="
                f"{response_method} instead."
            )
        prediction_method = estimator.predict
        y_pred, pos_label = prediction_method(X), None

    if return_response_method_used:
        return y_pred, pos_label, prediction_method.__name__
    return y_pred, pos_label
def _get_response_values_binary(
    estimator, X, response_method, pos_label=None, return_response_method_used=False
):
    """Compute the response values of a binary classifier.

    Validates that `estimator` is a fitted binary classifier, resolves the
    `"auto"` response method, then delegates to `_get_response_values`.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a binary classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    response_method : {'auto', 'predict_proba', 'decision_function'}
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    pos_label : int, float, bool or str, default=None
        The class considered as the positive class when computing
        the metrics. By default, `estimators.classes_[1]` is
        considered as the positive class.

    return_response_method_used : bool, default=False
        Whether to return the response method used to compute the response
        values.

        .. versionadded:: 1.5

    Returns
    -------
    y_pred : ndarray of shape (n_samples,)
        Target scores calculated from the provided response_method
        and pos_label.

    pos_label : int, float, bool or str
        The class considered as the positive class when computing
        the metrics.

    response_method_used : str
        The response method used to compute the response values. Only returned
        if `return_response_method_used` is `True`.

        .. versionadded:: 1.5
    """
    error_prefix = "Expected 'estimator' to be a binary classifier."

    check_is_fitted(estimator)
    if not is_classifier(estimator):
        raise ValueError(
            error_prefix + f" Got {estimator.__class__.__name__} instead."
        )
    if len(estimator.classes_) != 2:
        raise ValueError(
            error_prefix + f" Got {len(estimator.classes_)} classes instead."
        )

    # "auto" means predict_proba with a decision_function fallback.
    if response_method == "auto":
        response_method = ["predict_proba", "decision_function"]

    return _get_response_values(
        estimator,
        X,
        response_method,
        pos_label=pos_label,
        return_response_method_used=return_response_method_used,
    )

View File

@@ -0,0 +1,76 @@
{{py:

"""
Dataset abstractions for sequential data access.
Template file for easily generate fused types consistent code using Tempita
(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py).

Generated file: _seq_dataset.pxd

Each class is duplicated for all dtypes (float and double). The keywords
between double braces are substituted during the build.
"""

# name_suffix, c_type
dtypes = [('64', 'float64_t'),
          ('32', 'float32_t')]

}}
"""Dataset abstractions for sequential data access."""

from ._typedefs cimport float32_t, float64_t, intp_t, uint32_t


# SequentialDataset and its two concrete subclasses are (optionally randomized)
# iterators over the rows of a matrix X and corresponding target values y.

{{for name_suffix, c_type in dtypes}}

#------------------------------------------------------------------------------
cdef class SequentialDataset{{name_suffix}}:
    # Position of the current sample within ``index``.
    cdef int current_index
    # Sample ordering used for iteration; shuffled for randomized access.
    cdef int[::1] index
    # Raw pointer into ``index`` for nogil access.
    cdef int *index_data_ptr
    cdef Py_ssize_t n_samples
    # PRNG state for random sampling with replacement (see the .pyx file).
    cdef uint32_t seed

    cdef void shuffle(self, uint32_t seed) noexcept nogil
    cdef int _get_next_index(self) noexcept nogil
    cdef int _get_random_index(self) noexcept nogil

    # ``_sample`` hands out pointers into the underlying data; it is
    # implemented by the concrete subclasses below.
    cdef void _sample(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                      int *nnz, {{c_type}} *y, {{c_type}} *sample_weight,
                      int current_index) noexcept nogil
    cdef void next(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                   int *nnz, {{c_type}} *y, {{c_type}} *sample_weight) noexcept nogil
    cdef int random(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                    int *nnz, {{c_type}} *y, {{c_type}} *sample_weight) noexcept nogil


cdef class ArrayDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
    # Dense data; memoryview references keep the buffers alive.
    cdef const {{c_type}}[:, ::1] X
    cdef const {{c_type}}[::1] Y
    cdef const {{c_type}}[::1] sample_weights
    cdef Py_ssize_t n_features
    # Number of elements between two consecutive rows of ``X``.
    cdef intp_t X_stride
    cdef {{c_type}} *X_data_ptr
    cdef {{c_type}} *Y_data_ptr
    # Identity feature indices (0..n_features-1) handed out by ``_sample``.
    cdef const int[::1] feature_indices
    cdef int *feature_indices_ptr
    cdef {{c_type}} *sample_weight_data


cdef class CSRDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
    # CSR buffers; memoryview references keep them alive.
    cdef const {{c_type}}[::1] X_data
    cdef const int[::1] X_indptr
    cdef const int[::1] X_indices
    cdef const {{c_type}}[::1] Y
    cdef const {{c_type}}[::1] sample_weights
    cdef {{c_type}} *X_data_ptr
    cdef int *X_indptr_ptr
    cdef int *X_indices_ptr
    cdef {{c_type}} *Y_data_ptr
    cdef {{c_type}} *sample_weight_data

{{endfor}}

View File

@@ -0,0 +1,348 @@
{{py:
"""
Dataset abstractions for sequential data access.
Template file for easily generate fused types consistent code using Tempita
(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py).
Generated file: _seq_dataset.pyx
Each class is duplicated for all dtypes (float and double). The keywords
between double braces are substituted during the build.
"""
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
# name_suffix, c_type, np_type
dtypes = [('64', 'float64_t', 'np.float64'),
('32', 'float32_t', 'np.float32')]
}}
"""Dataset abstractions for sequential data access."""
import numpy as np
cimport cython
from libc.limits cimport INT_MAX
from ._random cimport our_rand_r
from ._typedefs cimport float32_t, float64_t, uint32_t
{{for name_suffix, c_type, np_type in dtypes}}
#------------------------------------------------------------------------------
cdef class SequentialDataset{{name_suffix}}:
    """Base class for datasets with sequential data access.

    SequentialDataset is used to iterate over the rows of a matrix X and
    corresponding target values y, i.e. to iterate over samples.

    There are two methods to get the next sample:
        - next : Iterate sequentially (optionally randomized)
        - random : Iterate randomly (with replacement)

    Attributes
    ----------
    index : np.ndarray
        Index array for fast shuffling.

    index_data_ptr : int
        Pointer to the index array.

    current_index : int
        Index of current sample in ``index``.
        The index of current sample in the data is given by
        index_data_ptr[current_index].

    n_samples : Py_ssize_t
        Number of samples in the dataset.

    seed : uint32_t
        Seed used for random sampling. This attribute is modified at each call to the
        `random` method.
    """

    cdef void next(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                   int *nnz, {{c_type}} *y, {{c_type}} *sample_weight) noexcept nogil:
        """Get the next example ``x`` from the dataset.

        This method gets the next sample looping sequentially over all samples.
        The order can be shuffled with the method ``shuffle``.
        Shuffling once before iterating over all samples corresponds to a
        random draw without replacement. It is used for instance in SGD solver.

        Parameters
        ----------
        x_data_ptr : {{c_type}}**
            A pointer to the {{c_type}} array which holds the feature
            values of the next example.

        x_ind_ptr : np.intc**
            A pointer to the int array which holds the feature
            indices of the next example.

        nnz : int*
            A pointer to an int holding the number of non-zero
            values of the next example.

        y : {{c_type}}*
            The target value of the next example.

        sample_weight : {{c_type}}*
            The weight of the next example.
        """
        cdef int current_index = self._get_next_index()
        self._sample(x_data_ptr, x_ind_ptr, nnz, y, sample_weight,
                     current_index)

    cdef int random(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                    int *nnz, {{c_type}} *y, {{c_type}} *sample_weight) noexcept nogil:
        """Get a random example ``x`` from the dataset.

        This method gets next sample chosen randomly over a uniform
        distribution. It corresponds to a random draw with replacement.
        It is used for instance in SAG solver.

        Parameters
        ----------
        x_data_ptr : {{c_type}}**
            A pointer to the {{c_type}} array which holds the feature
            values of the next example.

        x_ind_ptr : np.intc**
            A pointer to the int array which holds the feature
            indices of the next example.

        nnz : int*
            A pointer to an int holding the number of non-zero
            values of the next example.

        y : {{c_type}}*
            The target value of the next example.

        sample_weight : {{c_type}}*
            The weight of the next example.

        Returns
        -------
        current_index : int
            Index of current sample.
        """
        cdef int current_index = self._get_random_index()
        self._sample(x_data_ptr, x_ind_ptr, nnz, y, sample_weight,
                     current_index)
        return current_index

    cdef void shuffle(self, uint32_t seed) noexcept nogil:
        """Permutes the ordering of examples."""
        # Fisher-Yates shuffle.
        # NOTE: the PRNG state is the local ``seed`` argument (not
        # ``self.seed``), so the same seed always yields the same permutation.
        cdef int *ind = self.index_data_ptr
        cdef int n = self.n_samples
        cdef unsigned i, j
        for i in range(n - 1):
            j = i + our_rand_r(&seed) % (n - i)
            ind[i], ind[j] = ind[j], ind[i]

    cdef int _get_next_index(self) noexcept nogil:
        # Advance sequentially, wrapping around to 0 after the last sample.
        cdef int current_index = self.current_index
        if current_index >= (self.n_samples - 1):
            current_index = -1

        current_index += 1
        self.current_index = current_index
        return self.current_index

    cdef int _get_random_index(self) noexcept nogil:
        # Uniform draw with replacement; advances ``self.seed`` in place.
        cdef int n = self.n_samples
        cdef int current_index = our_rand_r(&self.seed) % n
        self.current_index = current_index
        return current_index

    cdef void _sample(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                      int *nnz, {{c_type}} *y, {{c_type}} *sample_weight,
                      int current_index) noexcept nogil:
        # Implemented by the concrete subclasses below.
        pass

    def _shuffle_py(self, uint32_t seed):
        """python function used for easy testing"""
        self.shuffle(seed)

    def _next_py(self):
        """python function used for easy testing"""
        cdef int current_index = self._get_next_index()
        return self._sample_py(current_index)

    def _random_py(self):
        """python function used for easy testing"""
        cdef int current_index = self._get_random_index()
        return self._sample_py(current_index)

    def _sample_py(self, int current_index):
        """python function used for easy testing"""
        cdef {{c_type}}* x_data_ptr
        cdef int* x_indices_ptr
        cdef int nnz, j
        cdef {{c_type}} y, sample_weight

        # call _sample in cython
        self._sample(&x_data_ptr, &x_indices_ptr, &nnz, &y, &sample_weight,
                     current_index)

        # transform the pointed data in numpy CSR array
        cdef {{c_type}}[:] x_data = np.empty(nnz, dtype={{np_type}})
        cdef int[:] x_indices = np.empty(nnz, dtype=np.int32)
        cdef int[:] x_indptr = np.asarray([0, nnz], dtype=np.int32)

        for j in range(nnz):
            x_data[j] = x_data_ptr[j]
            x_indices[j] = x_indices_ptr[j]

        cdef int sample_idx = self.index_data_ptr[current_index]

        return (
            (np.asarray(x_data), np.asarray(x_indices), np.asarray(x_indptr)),
            y,
            sample_weight,
            sample_idx,
        )
cdef class ArrayDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
    """Dataset backed by a two-dimensional numpy array.

    The dtype of the numpy array is expected to be ``{{np_type}}`` ({{c_type}})
    and C-style memory layout.
    """

    def __cinit__(
        self,
        const {{c_type}}[:, ::1] X,
        const {{c_type}}[::1] Y,
        const {{c_type}}[::1] sample_weights,
        uint32_t seed=1,
    ):
        """A ``SequentialDataset`` backed by a two-dimensional numpy array.

        Parameters
        ----------
        X : ndarray, dtype={{c_type}}, ndim=2, mode='c'
            The sample array, of shape(n_samples, n_features)

        Y : ndarray, dtype={{c_type}}, ndim=1, mode='c'
            The target array, of shape(n_samples, )

        sample_weights : ndarray, dtype={{c_type}}, ndim=1, mode='c'
            The weight of each sample, of shape(n_samples,)
        """
        # Sample and feature indices are stored as C ints below, hence the cap.
        if X.shape[0] > INT_MAX or X.shape[1] > INT_MAX:
            raise ValueError("More than %d samples or features not supported;"
                             " got (%d, %d)."
                             % (INT_MAX, X.shape[0], X.shape[1]))

        # keep a reference to the data to prevent garbage collection
        self.X = X
        self.Y = Y
        self.sample_weights = sample_weights
        self.n_samples = X.shape[0]
        self.n_features = X.shape[1]

        # A dense row contains every feature, so the index view handed out by
        # ``_sample`` is simply 0..n_features-1.
        self.feature_indices = np.arange(0, self.n_features, dtype=np.intc)
        self.feature_indices_ptr = <int *> &self.feature_indices[0]

        self.current_index = -1
        # Number of elements (not bytes) between two consecutive rows.
        self.X_stride = X.strides[0] // X.itemsize
        self.X_data_ptr = <{{c_type}} *> &X[0, 0]
        self.Y_data_ptr = <{{c_type}} *> &Y[0]
        self.sample_weight_data = <{{c_type}} *> &sample_weights[0]

        # Use index array for fast shuffling
        self.index = np.arange(0, self.n_samples, dtype=np.intc)
        self.index_data_ptr = <int *> &self.index[0]
        # seed should not be 0 for our_rand_r
        self.seed = max(seed, 1)

    cdef void _sample(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                      int *nnz, {{c_type}} *y, {{c_type}} *sample_weight,
                      int current_index) noexcept nogil:
        # Map the (possibly shuffled) iteration position back to a row number.
        cdef long long sample_idx = self.index_data_ptr[current_index]
        cdef long long offset = sample_idx * self.X_stride

        y[0] = self.Y_data_ptr[sample_idx]
        x_data_ptr[0] = self.X_data_ptr + offset
        x_ind_ptr[0] = self.feature_indices_ptr
        nnz[0] = self.n_features
        sample_weight[0] = self.sample_weight_data[sample_idx]
cdef class CSRDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
    """A ``SequentialDataset`` backed by a scipy sparse CSR matrix. """

    def __cinit__(
        self,
        const {{c_type}}[::1] X_data,
        const int[::1] X_indptr,
        const int[::1] X_indices,
        const {{c_type}}[::1] Y,
        const {{c_type}}[::1] sample_weights,
        uint32_t seed=1,
    ):
        """Dataset backed by a scipy sparse CSR matrix.

        The feature indices of ``x`` are given by x_ind_ptr[0:nnz].
        The corresponding feature values are given by
        x_data_ptr[0:nnz].

        Parameters
        ----------
        X_data : ndarray, dtype={{c_type}}, ndim=1, mode='c'
            The data array of the CSR features matrix.

        X_indptr : ndarray, dtype=np.intc, ndim=1, mode='c'
            The index pointer array of the CSR features matrix.

        X_indices : ndarray, dtype=np.intc, ndim=1, mode='c'
            The column indices array of the CSR features matrix.

        Y : ndarray, dtype={{c_type}}, ndim=1, mode='c'
            The target values.

        sample_weights : ndarray, dtype={{c_type}}, ndim=1, mode='c'
            The weight of each sample.
        """
        # keep a reference to the data to prevent garbage collection
        self.X_data = X_data
        self.X_indptr = X_indptr
        self.X_indices = X_indices
        self.Y = Y
        self.sample_weights = sample_weights
        self.n_samples = Y.shape[0]
        self.current_index = -1
        self.X_data_ptr = <{{c_type}} *> &X_data[0]
        self.X_indptr_ptr = <int *> &X_indptr[0]
        self.X_indices_ptr = <int *> &X_indices[0]
        self.Y_data_ptr = <{{c_type}} *> &Y[0]
        self.sample_weight_data = <{{c_type}} *> &sample_weights[0]

        # Use index array for fast shuffling
        self.index = np.arange(self.n_samples, dtype=np.intc)
        self.index_data_ptr = <int *> &self.index[0]
        # seed should not be 0 for our_rand_r
        self.seed = max(seed, 1)

    cdef void _sample(self, {{c_type}} **x_data_ptr, int **x_ind_ptr,
                      int *nnz, {{c_type}} *y, {{c_type}} *sample_weight,
                      int current_index) noexcept nogil:
        # Map the (possibly shuffled) iteration position back to a row number,
        # then hand out pointers into the row's CSR slice.
        cdef long long sample_idx = self.index_data_ptr[current_index]
        cdef long long offset = self.X_indptr_ptr[sample_idx]
        y[0] = self.Y_data_ptr[sample_idx]
        x_data_ptr[0] = self.X_data_ptr + offset
        x_ind_ptr[0] = self.X_indices_ptr + offset
        nnz[0] = self.X_indptr_ptr[sample_idx + 1] - offset
        sample_weight[0] = self.sample_weight_data[sample_idx]
{{endfor}}

View File

@@ -0,0 +1,460 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import importlib
from functools import wraps
from typing import Protocol, runtime_checkable
import numpy as np
from scipy.sparse import issparse
from .._config import get_config
from ._available_if import available_if
def check_library_installed(library):
    """Return the imported `library` module, raising a helpful error if absent."""
    try:
        module = importlib.import_module(library)
    except ImportError as exc:
        message = (
            f"Setting output container to '{library}' requires {library} to be"
            " installed"
        )
        raise ImportError(message) from exc
    return module
def get_columns(columns):
    """Resolve `columns`, calling it first when it is a callable.

    A callable that raises is treated as "no column names available" and
    yields ``None`` instead of propagating the error (best-effort lookup).
    """
    if not callable(columns):
        return columns
    try:
        resolved = columns()
    except Exception:
        # Deliberate best-effort: a failing column provider means "unknown".
        return None
    return resolved
@runtime_checkable
class ContainerAdapterProtocol(Protocol):
    """Structural type for dataframe-library adapters (pandas, polars, ...).

    Concrete adapters implement these methods for a single dataframe library
    and are registered with the module-level adapters manager.
    """

    # Top-level module name of the handled library, e.g. "pandas".
    container_lib: str

    def create_container(self, X_output, X_original, columns, inplace=False):
        """Create container from `X_output` with additional metadata.

        Parameters
        ----------
        X_output : {ndarray, dataframe}
            Data to wrap.

        X_original : {ndarray, dataframe}
            Original input dataframe. This is used to extract the metadata that should
            be passed to `X_output`, e.g. pandas row index.

        columns : callable, ndarray, or None
            The column names or a callable that returns the column names. The
            callable is useful if the column names require some computation. If `None`,
            then no columns are passed to the container's constructor.

        inplace : bool, default=False
            Whether or not we intend to modify `X_output` in-place. However, it does
            not guarantee that we return the same object if the in-place operation
            is not possible.

        Returns
        -------
        wrapped_output : container_type
            `X_output` wrapped into the container type.
        """

    def is_supported_container(self, X):
        """Return True if X is a supported container.

        Parameters
        ----------
        X : container
            Container to be checked.

        Returns
        -------
        is_supported_container : bool
            True if X is a supported container.
        """

    def rename_columns(self, X, columns):
        """Rename columns in `X`.

        Parameters
        ----------
        X : container
            Container whose columns are updated.

        columns : ndarray of str
            Columns to update the `X`'s columns with.

        Returns
        -------
        updated_container : container
            Container with new names.
        """

    def hstack(self, Xs):
        """Stack containers horizontally (column-wise).

        Parameters
        ----------
        Xs : list of containers
            List of containers to stack.

        Returns
        -------
        stacked_Xs : container
            Stacked containers.
        """
class PandasAdapter:
    """Adapter implementing `ContainerAdapterProtocol` for pandas DataFrames."""

    container_lib = "pandas"

    def create_container(self, X_output, X_original, columns, inplace=True):
        pd = check_library_installed("pandas")
        column_names = get_columns(columns)

        needs_new_frame = not (inplace and isinstance(X_output, pd.DataFrame))
        if needs_new_frame:
            # Recover the row index from whichever input carries one.
            # `getattr(obj, "index")` would be wrong here because plain
            # `list` objects also expose an `index` attribute (the method).
            if isinstance(X_output, pd.DataFrame):
                row_index = X_output.index
            elif isinstance(X_original, (pd.DataFrame, pd.Series)):
                row_index = X_original.index
            else:
                row_index = None
            # Column names are applied afterwards via `rename_columns`:
            # passing them to the constructor would mean column *selection*
            # instead of renaming.
            X_output = pd.DataFrame(X_output, index=row_index, copy=not inplace)

        if column_names is None:
            return X_output
        return self.rename_columns(X_output, column_names)

    def is_supported_container(self, X):
        pd = check_library_installed("pandas")
        return isinstance(X, pd.DataFrame)

    def rename_columns(self, X, columns):
        # Assign `X.columns` directly: `DataFrame.rename` takes a mapping and
        # cannot express the potentially duplicated names we may have here.
        X.columns = columns
        return X

    def hstack(self, Xs):
        pd = check_library_installed("pandas")
        return pd.concat(Xs, axis=1)
class PolarsAdapter:
    """Adapter implementing `ContainerAdapterProtocol` for polars DataFrames."""

    container_lib = "polars"

    def create_container(self, X_output, X_original, columns, inplace=True):
        pl = check_library_installed("polars")
        columns = get_columns(columns)
        # polars expects a plain Python list as schema, not a numpy array.
        columns = columns.tolist() if isinstance(columns, np.ndarray) else columns

        if not inplace or not isinstance(X_output, pl.DataFrame):
            # In all these cases, we need to create a new DataFrame.
            # `orient="row"` because `X_output` holds row-major samples. Note
            # that the schema is applied directly here, so the
            # `rename_columns` path below only runs for the in-place case.
            return pl.DataFrame(X_output, schema=columns, orient="row")

        if columns is not None:
            return self.rename_columns(X_output, columns)
        return X_output

    def is_supported_container(self, X):
        pl = check_library_installed("polars")
        return isinstance(X, pl.DataFrame)

    def rename_columns(self, X, columns):
        # we cannot use `rename` since it takes a dictionary and at this stage we have
        # potentially duplicate column names in `X`
        X.columns = columns
        return X

    def hstack(self, Xs):
        pl = check_library_installed("polars")
        return pl.concat(Xs, how="horizontal")
class ContainerAdaptersManager:
    """Registry mapping a dataframe library name to its adapter instance."""

    def __init__(self):
        # Keys are `adapter.container_lib` values, e.g. "pandas" or "polars".
        self.adapters = {}

    @property
    def supported_outputs(self):
        """Names accepted by `set_output`: "default" plus registered libraries."""
        supported = set(self.adapters)
        supported.add("default")
        return supported

    def register(self, adapter):
        """Register `adapter` under its `container_lib` name."""
        key = adapter.container_lib
        self.adapters[key] = adapter
# Module-level registry of dataframe adapters. Only pandas and polars are
# supported; third parties cannot extend this through a public API.
ADAPTERS_MANAGER = ContainerAdaptersManager()
ADAPTERS_MANAGER.register(PandasAdapter())
ADAPTERS_MANAGER.register(PolarsAdapter())
def _get_adapter_from_container(container):
    """Get the adapter that knows how to handle such container.

    See :class:`sklearn.utils._set_output.ContainerAdapterProtocol` for more
    details.
    """
    # The top-level package of the container's class identifies the library,
    # e.g. "pandas.core.frame" -> "pandas".
    lib_name = type(container).__module__.split(".")[0]
    try:
        adapter = ADAPTERS_MANAGER.adapters[lib_name]
    except KeyError as exc:
        registered = list(ADAPTERS_MANAGER.adapters.keys())
        raise ValueError(
            "The container does not have a registered adapter in scikit-learn. "
            f"Available adapters are: {registered} while the container "
            f"provided is: {container!r}."
        ) from exc
    return adapter
def _get_container_adapter(method, estimator=None):
    """Get the dataframe adapter configured for `estimator`'s `method`.

    Returns ``None`` when the configured dense output has no registered
    adapter (i.e. `"default"`), meaning no wrapping should happen.
    """
    dense_config = _get_output_config(method, estimator)["dense"]
    try:
        return ADAPTERS_MANAGER.adapters[dense_config]
    except KeyError:
        # "default" is a valid configuration with no adapter: no wrapping.
        return None
def _get_output_config(method, estimator=None):
    """Get output config based on estimator and global configuration.

    Parameters
    ----------
    method : {"transform"}
        Estimator's method for which the output container is looked up.

    estimator : estimator instance or None
        Estimator to get the output configuration from. If `None`, the
        global configuration is used.

    Returns
    -------
    config : dict
        Dictionary with keys:

        - "dense": specifies the dense container for `method`. This can be
          `"default"`, `"pandas"` or `"polars"`.
    """
    # A per-estimator setting (stored by `set_output`) takes precedence over
    # the global `sklearn.set_config` value.
    est_sklearn_output_config = getattr(estimator, "_sklearn_output_config", {})
    if method in est_sklearn_output_config:
        dense_config = est_sklearn_output_config[method]
    else:
        dense_config = get_config()[f"{method}_output"]

    supported_outputs = ADAPTERS_MANAGER.supported_outputs
    if dense_config not in supported_outputs:
        raise ValueError(
            f"output config must be in {sorted(supported_outputs)}, got {dense_config}"
        )

    return {"dense": dense_config}
def _wrap_data_with_container(method, data_to_wrap, original_input, estimator):
    """Wrap output with container based on an estimator's or global config.

    Parameters
    ----------
    method : {"transform"}
        Estimator's method to get container output for.

    data_to_wrap : {ndarray, dataframe}
        Data to wrap with container.

    original_input : {ndarray, dataframe}
        Original input of function.

    estimator : estimator instance
        Estimator to get the output configuration from.

    Returns
    -------
    output : {ndarray, dataframe}
        If the output config is "default" or the estimator is not configured
        for wrapping return `data_to_wrap` unchanged.
        If the output config is "pandas" or "polars", return `data_to_wrap`
        as a dataframe of that library.
    """
    output_config = _get_output_config(method, estimator)

    if output_config["dense"] == "default" or not _auto_wrap_is_configured(estimator):
        return data_to_wrap

    dense_config = output_config["dense"]
    if issparse(data_to_wrap):
        # Dataframe containers cannot represent scipy sparse matrices: the
        # user must either densify the output or disable wrapping.
        raise ValueError(
            "The transformer outputs a scipy sparse matrix. "
            "Try to set the transformer output to a dense array or disable "
            f"{dense_config.capitalize()} output with set_output(transform='default')."
        )

    adapter = ADAPTERS_MANAGER.adapters[dense_config]
    # Pass the bound method itself so column names are resolved lazily (see
    # `get_columns`, which tolerates the callable raising).
    return adapter.create_container(
        data_to_wrap,
        original_input,
        columns=estimator.get_feature_names_out,
    )
def _wrap_method_output(f, method):
"""Wrapper used by `_SetOutputMixin` to automatically wrap methods."""
@wraps(f)
def wrapped(self, X, *args, **kwargs):
data_to_wrap = f(self, X, *args, **kwargs)
if isinstance(data_to_wrap, tuple):
# only wrap the first output for cross decomposition
return_tuple = (
_wrap_data_with_container(method, data_to_wrap[0], X, self),
*data_to_wrap[1:],
)
# Support for namedtuples `_make` is a documented API for namedtuples:
# https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make
if hasattr(type(data_to_wrap), "_make"):
return type(data_to_wrap)._make(return_tuple)
return return_tuple
return _wrap_data_with_container(method, data_to_wrap, X, self)
return wrapped
def _auto_wrap_is_configured(estimator):
"""Return True if estimator is configured for auto-wrapping the transform method.
`_SetOutputMixin` sets `_sklearn_auto_wrap_output_keys` to `set()` if auto wrapping
is manually disabled.
"""
auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set())
return (
hasattr(estimator, "get_feature_names_out")
and "transform" in auto_wrap_output_keys
)
class _SetOutputMixin:
    """Mixin that dynamically wraps methods to return container based on config.

    Currently `_SetOutputMixin` wraps `transform` and `fit_transform` and configures
    it based on `set_output` of the global configuration.

    `set_output` is only defined if `get_feature_names_out` is defined and
    `auto_wrap_output_keys` is the default value.
    """

    def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
        super().__init_subclass__(**kwargs)

        # Dynamically wraps `transform` and `fit_transform` and configures
        # their output based on `set_output`.
        if not (
            isinstance(auto_wrap_output_keys, tuple) or auto_wrap_output_keys is None
        ):
            raise ValueError("auto_wrap_output_keys must be None or a tuple of keys.")

        if auto_wrap_output_keys is None:
            # Subclass opted out: the empty set disables auto-wrapping (and
            # hence `set_output`, via `_auto_wrap_is_configured`).
            cls._sklearn_auto_wrap_output_keys = set()
            return

        # Mapping from method to key in configurations: both `transform` and
        # `fit_transform` are controlled by the single "transform" config key.
        method_to_key = {
            "transform": "transform",
            "fit_transform": "transform",
        }
        cls._sklearn_auto_wrap_output_keys = set()

        for method, key in method_to_key.items():
            if not hasattr(cls, method) or key not in auto_wrap_output_keys:
                continue
            cls._sklearn_auto_wrap_output_keys.add(key)

            # Only wrap methods defined by cls itself; inherited methods were
            # already wrapped when the parent class was created.
            if method not in cls.__dict__:
                continue
            wrapped_method = _wrap_method_output(getattr(cls, method), key)
            setattr(cls, method, wrapped_method)

    @available_if(_auto_wrap_is_configured)
    def set_output(self, *, transform=None):
        """Set output container.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        Parameters
        ----------
        transform : {"default", "pandas", "polars"}, default=None
            Configure output of `transform` and `fit_transform`.

            - `"default"`: Default output format of a transformer
            - `"pandas"`: DataFrame output
            - `"polars"`: Polars output
            - `None`: Transform configuration is unchanged

            .. versionadded:: 1.4
                `"polars"` option was added.

        Returns
        -------
        self : estimator instance
            Estimator instance.
        """
        if transform is None:
            # No-op: keep the current configuration.
            return self

        if not hasattr(self, "_sklearn_output_config"):
            self._sklearn_output_config = {}

        self._sklearn_output_config["transform"] = transform
        return self
def _safe_set_output(estimator, *, transform=None):
"""Safely call estimator.set_output and error if it not available.
This is used by meta-estimators to set the output for child estimators.
Parameters
----------
estimator : estimator instance
Estimator instance.
transform : {"default", "pandas", "polars"}, default=None
Configure output of the following estimator's methods:
- `"transform"`
- `"fit_transform"`
If `None`, this operation is a no-op.
Returns
-------
estimator : estimator instance
Estimator instance.
"""
set_output_for_transform = hasattr(estimator, "transform") or (
hasattr(estimator, "fit_transform") and transform is not None
)
if not set_output_for_transform:
# If estimator can not transform, then `set_output` does not need to be
# called.
return
if not hasattr(estimator, "set_output"):
raise ValueError(
f"Unable to configure output for {estimator} because `set_output` "
"is not available."
)
return estimator.set_output(transform=transform)

View File

@@ -0,0 +1,115 @@
"""
Utility methods to print system info for debugging
adapted from :func:`pandas.show_versions`
"""
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import platform
import sys
from threadpoolctl import threadpool_info
from .. import __version__
from ._openmp_helpers import _openmp_parallelism_enabled
def _get_sys_info():
"""System information
Returns
-------
sys_info : dict
system and Python version information
"""
python = sys.version.replace("\n", " ")
blob = [
("python", python),
("executable", sys.executable),
("machine", platform.platform()),
]
return dict(blob)
def _get_deps_info():
    """Overview of the installed version of main dependencies

    This function does not import the modules to collect the version numbers
    but instead relies on standard Python package metadata.

    Returns
    -------
    deps_info: dict
        version information on relevant Python libraries
    """
    from importlib.metadata import PackageNotFoundError, version

    relevant_packages = (
        "pip",
        "setuptools",
        "numpy",
        "scipy",
        "Cython",
        "pandas",
        "matplotlib",
        "joblib",
        "threadpoolctl",
    )
    # scikit-learn itself is reported first, from the imported `__version__`.
    deps_info = {"sklearn": __version__}
    for package_name in relevant_packages:
        try:
            deps_info[package_name] = version(package_name)
        except PackageNotFoundError:
            # Missing package metadata: report the package as not installed.
            deps_info[package_name] = None
    return deps_info
def show_versions():
    """Print useful debugging information.

    .. versionadded:: 0.20

    Examples
    --------
    >>> from sklearn import show_versions
    >>> show_versions()  # doctest: +SKIP
    """
    sys_info = _get_sys_info()
    deps_info = _get_deps_info()

    print("\nSystem:")
    for k, stat in sys_info.items():
        print("{k:>10}: {stat}".format(k=k, stat=stat))

    print("\nPython dependencies:")
    for k, stat in deps_info.items():
        print("{k:>13}: {stat}".format(k=k, stat=stat))

    print(
        "\n{k}: {stat}".format(
            k="Built with OpenMP", stat=_openmp_parallelism_enabled()
        )
    )

    # show threadpoolctl results
    threadpool_results = threadpool_info()
    if threadpool_results:
        print()
        print("threadpoolctl info:")

        for i, result in enumerate(threadpool_results):
            for key, val in result.items():
                print(f"{key:>15}: {val}")
            if i != len(threadpool_results) - 1:
                # Blank line between entries, but not after the last one.
                print()

View File

@@ -0,0 +1,9 @@
from ._typedefs cimport intp_t

from cython cimport floating

# Declaration for the fused-type simultaneous quicksort implemented in the
# matching .pyx file: sorts `dist` ascending in place while applying the same
# swaps to `idx`, so the two arrays stay aligned.
cdef int simultaneous_sort(
    floating *dist,
    intp_t *idx,
    intp_t size,
) noexcept nogil

View File

@@ -0,0 +1,93 @@
from cython cimport floating
cdef inline void dual_swap(
    floating* darr,
    intp_t *iarr,
    intp_t a,
    intp_t b,
) noexcept nogil:
    """Swap the values at index a and b of both darr and iarr"""
    # Swap the floating values.
    cdef floating dtmp = darr[a]
    darr[a] = darr[b]
    darr[b] = dtmp

    # Swap the paired indices with the same permutation.
    cdef intp_t itmp = iarr[a]
    iarr[a] = iarr[b]
    iarr[b] = itmp
cdef int simultaneous_sort(
    floating* values,
    intp_t* indices,
    intp_t size,
) noexcept nogil:
    """
    Perform a recursive quicksort on the values array as to sort them ascendingly.
    This simultaneously performs the swaps on both the values and the indices arrays.

    The numpy equivalent is:

        def simultaneous_sort(dist, idx):
            i = np.argsort(dist)
            return dist[i], idx[i]

    Notes
    -----
    Arrays are manipulated via a pointer to their first element and their size
    as to ease the processing of dynamically allocated buffers.
    """
    # TODO: In order to support discrete distance metrics, we need to have a
    # simultaneous sort which breaks ties on indices when distances are identical.
    # The best might be using a std::stable_sort and a Comparator which might need
    # an Array of Structures (AoS) instead of the Structure of Arrays (SoA)
    # currently used.
    cdef:
        intp_t pivot_idx, i, store_idx
        floating pivot_val

    # in the small-array case, do things efficiently
    if size <= 1:
        pass
    elif size == 2:
        if values[0] > values[1]:
            dual_swap(values, indices, 0, 1)
    elif size == 3:
        # Tiny sorting network: at most three compare/swaps for 3 elements.
        if values[0] > values[1]:
            dual_swap(values, indices, 0, 1)
        if values[1] > values[2]:
            dual_swap(values, indices, 1, 2)
            if values[0] > values[1]:
                dual_swap(values, indices, 0, 1)
    else:
        # Determine the pivot using the median-of-three rule.
        # The smallest of the three is moved to the beginning of the array,
        # the middle (the pivot value) is moved to the end, and the largest
        # is moved to the pivot index.
        pivot_idx = size // 2
        if values[0] > values[size - 1]:
            dual_swap(values, indices, 0, size - 1)
        if values[size - 1] > values[pivot_idx]:
            dual_swap(values, indices, size - 1, pivot_idx)
            if values[0] > values[size - 1]:
                dual_swap(values, indices, 0, size - 1)
        pivot_val = values[size - 1]

        # Partition indices about pivot. At the end of this operation,
        # pivot_idx will contain the pivot value, everything to the left
        # will be smaller, and everything to the right will be larger.
        store_idx = 0
        for i in range(size - 1):
            if values[i] < pivot_val:
                dual_swap(values, indices, i, store_idx)
                store_idx += 1
        dual_swap(values, indices, store_idx, size - 1)
        pivot_idx = store_idx

        # Recursively sort each side of the pivot; sub-arrays of fewer than
        # two elements are already sorted, hence the guards.
        if pivot_idx > 1:
            simultaneous_sort(values, indices, pivot_idx)
        if pivot_idx + 2 < size:
            simultaneous_sort(values + pivot_idx + 1,
                              indices + pivot_idx + 1,
                              size - pivot_idx - 1)
    return 0

Some files were not shown because too many files have changed in this diff Show More