Add README
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from ._frozen import FrozenEstimator
|
||||
|
||||
__all__ = ["FrozenEstimator"]
|
||||
Binary file not shown.
Binary file not shown.
166 lines — venv/lib/python3.12/site-packages/sklearn/frozen/_frozen.py (new file)
@@ -0,0 +1,166 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from ..base import BaseEstimator
|
||||
from ..exceptions import NotFittedError
|
||||
from ..utils import get_tags
|
||||
from ..utils.metaestimators import available_if
|
||||
from ..utils.validation import check_is_fitted
|
||||
|
||||
|
||||
def _estimator_has(attr):
|
||||
"""Check that final_estimator has `attr`.
|
||||
|
||||
Used together with `available_if`.
|
||||
"""
|
||||
|
||||
def check(self):
|
||||
# raise original `AttributeError` if `attr` does not exist
|
||||
getattr(self.estimator, attr)
|
||||
return True
|
||||
|
||||
return check
|
||||
|
||||
|
||||
class FrozenEstimator(BaseEstimator):
    """Estimator that wraps a fitted estimator to prevent re-fitting.

    This meta-estimator takes an estimator and freezes it, in the sense that calling
    `fit` on it has no effect. `fit_predict` and `fit_transform` are also disabled.
    All other methods are delegated to the original estimator and original estimator's
    attributes are accessible as well.

    This is particularly useful when you have a fitted or a pre-trained model as a
    transformer in a pipeline, and you'd like `pipeline.fit` to have no effect on this
    step.

    Parameters
    ----------
    estimator : estimator
        The estimator which is to be kept frozen.

    See Also
    --------
    None: No similar entry in the scikit-learn documentation.

    Examples
    --------
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.frozen import FrozenEstimator
    >>> from sklearn.linear_model import LogisticRegression
    >>> X, y = make_classification(random_state=0)
    >>> clf = LogisticRegression(random_state=0).fit(X, y)
    >>> frozen_clf = FrozenEstimator(clf)
    >>> frozen_clf.fit(X, y)  # No-op
    FrozenEstimator(estimator=LogisticRegression(random_state=0))
    >>> frozen_clf.predict(X)  # Predictions from `clf.predict`
    array(...)
    """

    def __init__(self, estimator):
        self.estimator = estimator

    @available_if(_estimator_has("__getitem__"))
    def __getitem__(self, *args, **kwargs):
        """__getitem__ is defined in :class:`~sklearn.pipeline.Pipeline` and \
        :class:`~sklearn.compose.ColumnTransformer`.
        """
        return self.estimator.__getitem__(*args, **kwargs)

    def __getattr__(self, name):
        # `estimator`'s attributes are now accessible except `fit_predict` and
        # `fit_transform`, which would amount to re-fitting the wrapped estimator.
        if name in ["fit_predict", "fit_transform"]:
            raise AttributeError(f"{name} is not available for frozen estimators.")
        return getattr(self.estimator, name)

    def __sklearn_clone__(self):
        # `clone` returns the instance itself, keeping the fitted inner
        # estimator instead of producing an unfitted copy.
        return self

    def __sklearn_is_fitted__(self):
        # The wrapper is "fitted" exactly when the wrapped estimator is.
        try:
            check_is_fitted(self.estimator)
            return True
        except NotFittedError:
            return False

    def fit(self, X, y=None, *args, **kwargs):
        """No-op.

        As a frozen estimator, calling `fit` has no effect.

        Parameters
        ----------
        X : object
            Ignored.

        y : object, default=None
            Ignored.

        *args : tuple
            Additional positional arguments. Ignored, but present for API compatibility
            with `self.estimator`.

        **kwargs : dict
            Additional keyword arguments. Ignored, but present for API compatibility
            with `self.estimator`.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        # Even though fit is a no-op, a frozen estimator must wrap a *fitted*
        # estimator; raise NotFittedError here rather than later at predict time.
        check_is_fitted(self.estimator)
        return self

    def set_params(self, **kwargs):
        """Set the parameters of this estimator.

        The only valid key here is `estimator`. You cannot set the parameters of the
        inner estimator.

        Parameters
        ----------
        **kwargs : dict
            Estimator parameters.

        Returns
        -------
        self : FrozenEstimator
            This estimator.
        """
        estimator = kwargs.pop("estimator", None)
        if estimator is not None:
            self.estimator = estimator
        if kwargs:
            raise ValueError(
                "You cannot set parameters of the inner estimator in a frozen "
                "estimator since calling `fit` has no effect. You can use "
                "`frozenestimator.estimator.set_params` to set parameters of the inner "
                "estimator."
            )
        # Return the instance, as documented above and as required by the
        # scikit-learn `set_params` contract.
        return self

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Returns a `{"estimator": estimator}` dict. The parameters of the inner
        estimator are not included.

        Parameters
        ----------
        deep : bool, default=True
            Ignored.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        return {"estimator": self.estimator}

    def __sklearn_tags__(self):
        # Copy the wrapped estimator's tags so the wrapper advertises the same
        # capabilities, but opt out of the common estimator checks.
        tags = deepcopy(get_tags(self.estimator))
        tags._skip_test = True
        return tags
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,223 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
from sklearn import config_context
|
||||
from sklearn.base import (
|
||||
BaseEstimator,
|
||||
clone,
|
||||
is_classifier,
|
||||
is_clusterer,
|
||||
is_outlier_detector,
|
||||
is_regressor,
|
||||
)
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.datasets import make_classification, make_regression
|
||||
from sklearn.exceptions import NotFittedError, UnsetMetadataPassedError
|
||||
from sklearn.frozen import FrozenEstimator
|
||||
from sklearn.linear_model import LinearRegression, LogisticRegression
|
||||
from sklearn.neighbors import LocalOutlierFactor
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import RobustScaler, StandardScaler
|
||||
from sklearn.utils._testing import set_random_state
|
||||
from sklearn.utils.validation import check_is_fitted
|
||||
|
||||
|
||||
@pytest.fixture
def regression_dataset():
    """Fresh synthetic regression data as an ``(X, y)`` tuple."""
    X, y = make_regression()
    return X, y
|
||||
|
||||
|
||||
@pytest.fixture
def classification_dataset():
    """Fresh synthetic classification data as an ``(X, y)`` tuple."""
    X, y = make_classification()
    return X, y
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "estimator, dataset",
    [
        (LinearRegression(), "regression_dataset"),
        (LogisticRegression(), "classification_dataset"),
        (make_pipeline(StandardScaler(), LinearRegression()), "regression_dataset"),
        (
            make_pipeline(StandardScaler(), LogisticRegression()),
            "classification_dataset",
        ),
        (StandardScaler(), "regression_dataset"),
        (KMeans(), "regression_dataset"),
        (LocalOutlierFactor(), "regression_dataset"),
        (
            make_column_transformer(
                (StandardScaler(), [0]),
                (RobustScaler(), [1]),
            ),
            "regression_dataset",
        ),
    ],
)
@pytest.mark.parametrize(
    "method",
    ["predict", "predict_proba", "predict_log_proba", "decision_function", "transform"],
)
def test_frozen_methods(estimator, dataset, request, method):
    """Test that frozen.fit doesn't do anything, and that all other methods are
    exposed by the frozen estimator and return the same values as the estimator.
    """
    # `dataset` is the *name* of a fixture; resolve it at run time.
    X, y = request.getfixturevalue(dataset)
    set_random_state(estimator)
    estimator.fit(X, y)
    frozen = FrozenEstimator(estimator)
    # this should be no-op
    frozen.fit([[1]], [1])

    # Not every estimator implements every parametrized method; compare only
    # the ones that exist on the wrapped estimator.
    if hasattr(estimator, method):
        assert_array_equal(getattr(estimator, method)(X), getattr(frozen, method)(X))

    # The wrapper must preserve the estimator-type predicates of the inner
    # estimator (classifier/regressor/clusterer/outlier detector).
    assert is_classifier(estimator) == is_classifier(frozen)
    assert is_regressor(estimator) == is_regressor(frozen)
    assert is_clusterer(estimator) == is_clusterer(frozen)
    assert is_outlier_detector(estimator) == is_outlier_detector(frozen)
|
||||
|
||||
|
||||
@config_context(enable_metadata_routing=True)
def test_frozen_metadata_routing(regression_dataset):
    """Test that metadata routing works with frozen estimators."""

    class ConsumesMetadata(BaseEstimator):
        def __init__(self, on_fit=None, on_predict=None):
            # Flags selecting whether metadata is expected in fit / predict.
            self.on_fit = on_fit
            self.on_predict = on_predict

        def fit(self, X, y, metadata=None):
            if self.on_fit:
                assert metadata is not None
            self.fitted_ = True
            return self

        def predict(self, X, metadata=None):
            if self.on_predict:
                assert metadata is not None
            return np.ones(len(X))

    X, y = regression_dataset
    # Request metadata routing for both fit and predict on the inner step.
    pipeline = make_pipeline(
        ConsumesMetadata(on_fit=True, on_predict=True)
        .set_fit_request(metadata=True)
        .set_predict_request(metadata=True)
    )

    # Metadata passed through the frozen wrapper reaches the inner step just as
    # it does on the unwrapped pipeline.
    pipeline.fit(X, y, metadata="test")
    frozen = FrozenEstimator(pipeline)
    pipeline.predict(X, metadata="test")
    frozen.predict(X, metadata="test")

    # Changing the request on the inner step must be visible through the frozen
    # wrapper: with the request set to False, metadata is no longer routed.
    frozen["consumesmetadata"].set_predict_request(metadata=False)
    with pytest.raises(
        TypeError,
        match=re.escape(
            "Pipeline.predict got unexpected argument(s) {'metadata'}, which are not "
            "routed to any object."
        ),
    ):
        frozen.predict(X, metadata="test")

    # With the request left unset (None), passing metadata is an explicit error.
    frozen["consumesmetadata"].set_predict_request(metadata=None)
    with pytest.raises(UnsetMetadataPassedError):
        frozen.predict(X, metadata="test")
|
||||
|
||||
|
||||
def test_composite_fit(classification_dataset):
    """Calling fit_transform/fit_predict on a frozen estimator must not fit."""

    class Estimator(BaseEstimator):
        def fit(self, X, y):
            # Count calls so the test can assert fit ran exactly once.
            self._fit_counter = getattr(self, "_fit_counter", 0) + 1
            return self

        def fit_transform(self, X, y=None):
            # only here to test that it doesn't get called
            ...  # pragma: no cover

        def fit_predict(self, X, y=None):
            # only here to test that it doesn't get called
            ...  # pragma: no cover

    X, y = classification_dataset
    frozen = FrozenEstimator(Estimator().fit(X, y))

    # Both composite methods are explicitly disabled on the wrapper.
    with pytest.raises(AttributeError):
        frozen.fit_predict(X, y)
    with pytest.raises(AttributeError):
        frozen.fit_transform(X, y)

    # The one and only fit happened before freezing.
    assert frozen._fit_counter == 1
|
||||
|
||||
|
||||
def test_clone_frozen(regression_dataset):
    """Cloning a frozen estimator must keep the very same fitted inner estimator."""
    X, y = regression_dataset
    fitted = LinearRegression().fit(X, y)
    duplicate = clone(FrozenEstimator(fitted))
    # Identity, not equality: clone must not produce an unfitted copy.
    assert duplicate.estimator is fitted
|
||||
|
||||
|
||||
def test_check_is_fitted(regression_dataset):
    """check_is_fitted must reflect the state of the wrapped estimator."""
    X, y = regression_dataset

    # Unfitted inner estimator -> the frozen wrapper reports unfitted too.
    with pytest.raises(NotFittedError):
        check_is_fitted(FrozenEstimator(LinearRegression()))

    # Fitted inner estimator -> the check passes without raising.
    check_is_fitted(FrozenEstimator(LinearRegression().fit(X, y)))
|
||||
|
||||
|
||||
def test_frozen_tags():
    """Frozen estimators carry the wrapped estimator's tags, flipping only
    the `_skip_test` flag."""

    class Estimator(BaseEstimator):
        def __sklearn_tags__(self):
            tags = super().__sklearn_tags__()
            tags.input_tags.categorical = True
            return tags

    inner = Estimator()
    wrapped = FrozenEstimator(inner)
    wrapped_tags = wrapped.__sklearn_tags__()
    inner_tags = inner.__sklearn_tags__()

    # Only the frozen wrapper opts out of the common estimator checks.
    assert wrapped_tags._skip_test is True
    assert inner_tags._skip_test is False

    # The custom tag set by the inner estimator is propagated unchanged.
    assert inner_tags.input_tags.categorical is True
    assert wrapped_tags.input_tags.categorical is True
|
||||
|
||||
|
||||
def test_frozen_params():
    """FrozenEstimator exposes `estimator` as its only parameter."""
    inner = LogisticRegression()
    frozen = FrozenEstimator(inner)

    # Nested parameters of the inner estimator cannot be set via the wrapper.
    with pytest.raises(ValueError, match="You cannot set parameters of the inner"):
        frozen.set_params(estimator__C=1)

    assert frozen.get_params() == {"estimator": inner}

    # Swapping out the wrapped estimator itself is allowed.
    replacement = LocalOutlierFactor()
    frozen.set_params(estimator=replacement)
    assert frozen.get_params() == {"estimator": replacement}
|
||||
Reference in New Issue
Block a user