add read me

2026-01-09 10:28:44 +11:00
commit edaf914b73
13417 changed files with 2952119 additions and 0 deletions


@@ -0,0 +1,2 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause


@@ -0,0 +1,499 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from itertools import product
import numpy as np
from ...base import is_classifier
from ...utils._optional_dependencies import check_matplotlib_support
from ...utils._plotting import _validate_style_kwargs
from ...utils.multiclass import unique_labels
from .. import confusion_matrix
class ConfusionMatrixDisplay:
"""Confusion Matrix visualization.
It is recommended to use
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator` or
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` to
create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
attributes.
For general information regarding `scikit-learn` visualization tools, see
the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the
:ref:`Model Evaluation Guide <confusion_matrix>`.
Parameters
----------
confusion_matrix : ndarray of shape (n_classes, n_classes)
Confusion matrix.
display_labels : ndarray of shape (n_classes,), default=None
Display labels for plot. If None, display labels are set from 0 to
`n_classes - 1`.
Attributes
----------
im_ : matplotlib AxesImage
Image representing the confusion matrix.
text_ : ndarray of shape (n_classes, n_classes), dtype=matplotlib Text, \
or None
Array of matplotlib Text objects giving the value rendered in each
cell. `None` if `include_values` is false.
ax_ : matplotlib Axes
Axes with confusion matrix.
figure_ : matplotlib Figure
Figure containing the confusion matrix.
See Also
--------
confusion_matrix : Compute Confusion Matrix to evaluate the accuracy of a
classification.
ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
given an estimator, the data, and the label.
ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
given the true and predicted labels.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
... random_state=0)
>>> clf = SVC(random_state=0)
>>> clf.fit(X_train, y_train)
SVC(random_state=0)
>>> predictions = clf.predict(X_test)
>>> cm = confusion_matrix(y_test, predictions, labels=clf.classes_)
>>> disp = ConfusionMatrixDisplay(confusion_matrix=cm,
... display_labels=clf.classes_)
>>> disp.plot()
<...>
>>> plt.show()
"""
def __init__(self, confusion_matrix, *, display_labels=None):
self.confusion_matrix = confusion_matrix
self.display_labels = display_labels
def plot(
self,
*,
include_values=True,
cmap="viridis",
xticks_rotation="horizontal",
values_format=None,
ax=None,
colorbar=True,
im_kw=None,
text_kw=None,
):
"""Plot visualization.
Parameters
----------
include_values : bool, default=True
Includes values in confusion matrix.
cmap : str or matplotlib Colormap, default='viridis'
Colormap recognized by matplotlib.
xticks_rotation : {'vertical', 'horizontal'} or float, \
default='horizontal'
Rotation of xtick labels.
values_format : str, default=None
Format specification for values in confusion matrix. If `None`,
the format specification is 'd' or '.2g', whichever is shorter.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
colorbar : bool, default=True
Whether or not to add a colorbar to the plot.
im_kw : dict, default=None
Dict with keywords passed to `matplotlib.pyplot.imshow` call.
text_kw : dict, default=None
Dict with keywords passed to `matplotlib.pyplot.text` call.
.. versionadded:: 1.2
Returns
-------
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
Returns a :class:`~sklearn.metrics.ConfusionMatrixDisplay` instance
that contains all the information to plot the confusion matrix.
"""
check_matplotlib_support("ConfusionMatrixDisplay.plot")
import matplotlib.pyplot as plt
if ax is None:
fig, ax = plt.subplots()
else:
fig = ax.figure
cm = self.confusion_matrix
n_classes = cm.shape[0]
default_im_kw = dict(interpolation="nearest", cmap=cmap)
im_kw = im_kw or {}
im_kw = _validate_style_kwargs(default_im_kw, im_kw)
text_kw = text_kw or {}
self.im_ = ax.imshow(cm, **im_kw)
self.text_ = None
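# Colormap endpoints: each cell's text uses whichever endpoint color
# contrasts with that cell's background.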
cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0)
if include_values:
self.text_ = np.empty_like(cm, dtype=object)
# print text with appropriate color depending on background
thresh = (cm.max() + cm.min()) / 2.0
for i, j in product(range(n_classes), range(n_classes)):
color = cmap_max if cm[i, j] < thresh else cmap_min
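# Without an explicit values_format, render with '.2g' and, for
# integer-valued matrices, fall back to plain 'd' when that is shorter.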
if values_format is None:
text_cm = format(cm[i, j], ".2g")
if cm.dtype.kind != "f":
text_d = format(cm[i, j], "d")
if len(text_d) < len(text_cm):
text_cm = text_d
else:
text_cm = format(cm[i, j], values_format)
default_text_kwargs = dict(ha="center", va="center", color=color)
text_kwargs = _validate_style_kwargs(default_text_kwargs, text_kw)
self.text_[i, j] = ax.text(j, i, text_cm, **text_kwargs)
if self.display_labels is None:
display_labels = np.arange(n_classes)
else:
display_labels = self.display_labels
if colorbar:
fig.colorbar(self.im_, ax=ax)
ax.set(
xticks=np.arange(n_classes),
yticks=np.arange(n_classes),
xticklabels=display_labels,
yticklabels=display_labels,
ylabel="True label",
xlabel="Predicted label",
)
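# Invert the y-axis limits so the first class appears on the top row.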
ax.set_ylim((n_classes - 0.5, -0.5))
plt.setp(ax.get_xticklabels(), rotation=xticks_rotation)
self.figure_ = fig
self.ax_ = ax
return self
@classmethod
def from_estimator(
cls,
estimator,
X,
y,
*,
labels=None,
sample_weight=None,
normalize=None,
display_labels=None,
include_values=True,
xticks_rotation="horizontal",
values_format=None,
cmap="viridis",
ax=None,
colorbar=True,
im_kw=None,
text_kw=None,
):
"""Plot Confusion Matrix given an estimator and some data.
For general information regarding `scikit-learn` visualization tools, see
the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the
:ref:`Model Evaluation Guide <confusion_matrix>`.
.. versionadded:: 1.0
Parameters
----------
estimator : estimator instance
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
in which the last estimator is a classifier.
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.
y : array-like of shape (n_samples,)
Target values.
labels : array-like of shape (n_classes,), default=None
List of labels to index the confusion matrix. This may be used to
reorder or select a subset of labels. If `None` is given, those
that appear at least once in `y_true` or `y_pred` are used in
sorted order.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
normalize : {'true', 'pred', 'all'}, default=None
Whether to normalize the counts displayed in the matrix:
- if `'true'`, the confusion matrix is normalized over the true
conditions (i.e. rows);
- if `'pred'`, the confusion matrix is normalized over the
predicted conditions (i.e. columns);
- if `'all'`, the confusion matrix is normalized by the total
number of samples;
- if `None` (default), the confusion matrix will not be normalized.
display_labels : array-like of shape (n_classes,), default=None
Target names used for plotting. By default, `labels` will be used
if it is defined, otherwise the unique labels of `y_true` and
`y_pred` will be used.
include_values : bool, default=True
Includes values in confusion matrix.
xticks_rotation : {'vertical', 'horizontal'} or float, \
default='horizontal'
Rotation of xtick labels.
values_format : str, default=None
Format specification for values in confusion matrix. If `None`, the
format specification is 'd' or '.2g', whichever is shorter.
cmap : str or matplotlib Colormap, default='viridis'
Colormap recognized by matplotlib.
ax : matplotlib Axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
colorbar : bool, default=True
Whether or not to add a colorbar to the plot.
im_kw : dict, default=None
Dict with keywords passed to `matplotlib.pyplot.imshow` call.
text_kw : dict, default=None
Dict with keywords passed to `matplotlib.pyplot.text` call.
.. versionadded:: 1.2
Returns
-------
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
See Also
--------
ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
given the true and predicted labels.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import ConfusionMatrixDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, random_state=0)
>>> clf = SVC(random_state=0)
>>> clf.fit(X_train, y_train)
SVC(random_state=0)
>>> ConfusionMatrixDisplay.from_estimator(
... clf, X_test, y_test)
<...>
>>> plt.show()
For a detailed example of using a confusion matrix to evaluate a
Support Vector Classifier, please see
:ref:`sphx_glr_auto_examples_model_selection_plot_confusion_matrix.py`
"""
method_name = f"{cls.__name__}.from_estimator"
check_matplotlib_support(method_name)
if not is_classifier(estimator):
raise ValueError(f"{method_name} only supports classifiers")
y_pred = estimator.predict(X)
return cls.from_predictions(
y,
y_pred,
sample_weight=sample_weight,
labels=labels,
normalize=normalize,
display_labels=display_labels,
include_values=include_values,
cmap=cmap,
ax=ax,
xticks_rotation=xticks_rotation,
values_format=values_format,
colorbar=colorbar,
im_kw=im_kw,
text_kw=text_kw,
)
@classmethod
def from_predictions(
cls,
y_true,
y_pred,
*,
labels=None,
sample_weight=None,
normalize=None,
display_labels=None,
include_values=True,
xticks_rotation="horizontal",
values_format=None,
cmap="viridis",
ax=None,
colorbar=True,
im_kw=None,
text_kw=None,
):
"""Plot Confusion Matrix given true and predicted labels.
For general information regarding `scikit-learn` visualization tools, see
the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the
:ref:`Model Evaluation Guide <confusion_matrix>`.
.. versionadded:: 1.0
Parameters
----------
y_true : array-like of shape (n_samples,)
True labels.
y_pred : array-like of shape (n_samples,)
The predicted labels given by the `predict` method of a
classifier.
labels : array-like of shape (n_classes,), default=None
List of labels to index the confusion matrix. This may be used to
reorder or select a subset of labels. If `None` is given, those
that appear at least once in `y_true` or `y_pred` are used in
sorted order.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
normalize : {'true', 'pred', 'all'}, default=None
Whether to normalize the counts displayed in the matrix:
- if `'true'`, the confusion matrix is normalized over the true
conditions (i.e. rows);
- if `'pred'`, the confusion matrix is normalized over the
predicted conditions (i.e. columns);
- if `'all'`, the confusion matrix is normalized by the total
number of samples;
- if `None` (default), the confusion matrix will not be normalized.
display_labels : array-like of shape (n_classes,), default=None
Target names used for plotting. By default, `labels` will be used
if it is defined, otherwise the unique labels of `y_true` and
`y_pred` will be used.
include_values : bool, default=True
Includes values in confusion matrix.
xticks_rotation : {'vertical', 'horizontal'} or float, \
default='horizontal'
Rotation of xtick labels.
values_format : str, default=None
Format specification for values in confusion matrix. If `None`, the
format specification is 'd' or '.2g', whichever is shorter.
cmap : str or matplotlib Colormap, default='viridis'
Colormap recognized by matplotlib.
ax : matplotlib Axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
colorbar : bool, default=True
Whether or not to add a colorbar to the plot.
im_kw : dict, default=None
Dict with keywords passed to `matplotlib.pyplot.imshow` call.
text_kw : dict, default=None
Dict with keywords passed to `matplotlib.pyplot.text` call.
.. versionadded:: 1.2
Returns
-------
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
See Also
--------
ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
given an estimator, the data, and the label.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import ConfusionMatrixDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, random_state=0)
>>> clf = SVC(random_state=0)
>>> clf.fit(X_train, y_train)
SVC(random_state=0)
>>> y_pred = clf.predict(X_test)
>>> ConfusionMatrixDisplay.from_predictions(
... y_test, y_pred)
<...>
>>> plt.show()
"""
check_matplotlib_support(f"{cls.__name__}.from_predictions")
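# Tick labels fall back from display_labels to labels to the sorted
# unique labels found in y_true and y_pred.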
if display_labels is None:
if labels is None:
display_labels = unique_labels(y_true, y_pred)
else:
display_labels = labels
cm = confusion_matrix(
y_true,
y_pred,
sample_weight=sample_weight,
labels=labels,
normalize=normalize,
)
disp = cls(confusion_matrix=cm, display_labels=display_labels)
return disp.plot(
include_values=include_values,
cmap=cmap,
ax=ax,
xticks_rotation=xticks_rotation,
values_format=values_format,
colorbar=colorbar,
im_kw=im_kw,
text_kw=text_kw,
)


@@ -0,0 +1,371 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
import scipy as sp
from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
from .._ranking import det_curve
class DetCurveDisplay(_BinaryClassifierCurveDisplayMixin):
"""Detection Error Tradeoff (DET) curve visualization.
It is recommended to use :func:`~sklearn.metrics.DetCurveDisplay.from_estimator`
or :func:`~sklearn.metrics.DetCurveDisplay.from_predictions` to create a
visualizer. All parameters are stored as attributes.
For general information regarding `scikit-learn` visualization tools, see
the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the
:ref:`Model Evaluation Guide <det_curve>`.
.. versionadded:: 0.24
Parameters
----------
fpr : ndarray
False positive rate.
fnr : ndarray
False negative rate.
estimator_name : str, default=None
Name of estimator. If None, the estimator name is not shown.
pos_label : int, float, bool or str, default=None
The label of the positive class.
Attributes
----------
line_ : matplotlib Artist
DET Curve.
ax_ : matplotlib Axes
Axes with DET Curve.
figure_ : matplotlib Figure
Figure containing the curve.
See Also
--------
det_curve : Compute error rates for different probability thresholds.
DetCurveDisplay.from_estimator : Plot DET curve given an estimator and
some data.
DetCurveDisplay.from_predictions : Plot DET curve given the true and
predicted labels.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import det_curve, DetCurveDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(n_samples=1000, random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, test_size=0.4, random_state=0)
>>> clf = SVC(random_state=0).fit(X_train, y_train)
>>> y_pred = clf.decision_function(X_test)
>>> fpr, fnr, _ = det_curve(y_test, y_pred)
>>> display = DetCurveDisplay(
... fpr=fpr, fnr=fnr, estimator_name="SVC"
... )
>>> display.plot()
<...>
>>> plt.show()
"""
def __init__(self, *, fpr, fnr, estimator_name=None, pos_label=None):
self.fpr = fpr
self.fnr = fnr
self.estimator_name = estimator_name
self.pos_label = pos_label
@classmethod
def from_estimator(
cls,
estimator,
X,
y,
*,
sample_weight=None,
drop_intermediate=True,
response_method="auto",
pos_label=None,
name=None,
ax=None,
**kwargs,
):
"""Plot DET curve given an estimator and data.
For general information regarding `scikit-learn` visualization tools, see
the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the
:ref:`Model Evaluation Guide <det_curve>`.
.. versionadded:: 1.0
Parameters
----------
estimator : estimator instance
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
in which the last estimator is a classifier.
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.
y : array-like of shape (n_samples,)
Target values.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
drop_intermediate : bool, default=True
Whether to drop thresholds where true positives (tp) do not change
from the previous or subsequent threshold. All points with the same
tp value have the same `fnr` and thus the same y coordinate.
.. versionadded:: 1.7
response_method : {'predict_proba', 'decision_function', 'auto'}, \
default='auto'
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the predicted target response. If set
to 'auto', :term:`predict_proba` is tried first and if it does not
exist :term:`decision_function` is tried next.
pos_label : int, float, bool or str, default=None
The label of the positive class. When `pos_label=None`, if `y_true`
is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
error will be raised.
name : str, default=None
Name of DET curve for labeling. If `None`, use the name of the
estimator.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
**kwargs : dict
Additional keywords arguments passed to matplotlib `plot` function.
Returns
-------
display : :class:`~sklearn.metrics.DetCurveDisplay`
Object that stores computed values.
See Also
--------
det_curve : Compute error rates for different probability thresholds.
DetCurveDisplay.from_predictions : Plot DET curve given the true and
predicted labels.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import DetCurveDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(n_samples=1000, random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, test_size=0.4, random_state=0)
>>> clf = SVC(random_state=0).fit(X_train, y_train)
>>> DetCurveDisplay.from_estimator(
... clf, X_test, y_test)
<...>
>>> plt.show()
"""
y_pred, pos_label, name = cls._validate_and_get_response_values(
estimator,
X,
y,
response_method=response_method,
pos_label=pos_label,
name=name,
)
return cls.from_predictions(
y_true=y,
y_pred=y_pred,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
name=name,
ax=ax,
pos_label=pos_label,
**kwargs,
)
@classmethod
def from_predictions(
cls,
y_true,
y_pred,
*,
sample_weight=None,
drop_intermediate=True,
pos_label=None,
name=None,
ax=None,
**kwargs,
):
"""Plot the DET curve given the true and predicted labels.
For general information regarding `scikit-learn` visualization tools, see
the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the
:ref:`Model Evaluation Guide <det_curve>`.
.. versionadded:: 1.0
Parameters
----------
y_true : array-like of shape (n_samples,)
True labels.
y_pred : array-like of shape (n_samples,)
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by `decision_function` on some classifiers).
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
drop_intermediate : bool, default=True
Whether to drop thresholds where true positives (tp) do not change
from the previous or subsequent threshold. All points with the same
tp value have the same `fnr` and thus the same y coordinate.
.. versionadded:: 1.7
pos_label : int, float, bool or str, default=None
The label of the positive class. When `pos_label=None`, if `y_true`
is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
error will be raised.
name : str, default=None
Name of DET curve for labeling. If `None`, name will be set to
`"Classifier"`.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
**kwargs : dict
Additional keywords arguments passed to matplotlib `plot` function.
Returns
-------
display : :class:`~sklearn.metrics.DetCurveDisplay`
Object that stores computed values.
See Also
--------
det_curve : Compute error rates for different probability thresholds.
DetCurveDisplay.from_estimator : Plot DET curve given an estimator and
some data.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import DetCurveDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(n_samples=1000, random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, test_size=0.4, random_state=0)
>>> clf = SVC(random_state=0).fit(X_train, y_train)
>>> y_pred = clf.decision_function(X_test)
>>> DetCurveDisplay.from_predictions(
... y_test, y_pred)
<...>
>>> plt.show()
"""
pos_label_validated, name = cls._validate_from_predictions_params(
y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
)
fpr, fnr, _ = det_curve(
y_true,
y_pred,
pos_label=pos_label,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
)
viz = cls(
fpr=fpr,
fnr=fnr,
estimator_name=name,
pos_label=pos_label_validated,
)
return viz.plot(ax=ax, name=name, **kwargs)
def plot(self, ax=None, *, name=None, **kwargs):
"""Plot visualization.
Parameters
----------
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
name : str, default=None
Name of DET curve for labeling. If `None`, use `estimator_name` if
it is not `None`, otherwise no labeling is shown.
**kwargs : dict
Additional keywords arguments passed to matplotlib `plot` function.
Returns
-------
display : :class:`~sklearn.metrics.DetCurveDisplay`
Object that stores computed values.
"""
self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
line_kwargs = {} if name is None else {"label": name}
line_kwargs.update(**kwargs)
# We have the following bounds:
# sp.stats.norm.ppf(0.0) = -np.inf
# sp.stats.norm.ppf(1.0) = np.inf
# We therefore clip to eps and 1 - eps to not provide infinity to matplotlib.
eps = np.finfo(self.fpr.dtype).eps
self.fpr = self.fpr.clip(eps, 1 - eps)
self.fnr = self.fnr.clip(eps, 1 - eps)
(self.line_,) = self.ax_.plot(
sp.stats.norm.ppf(self.fpr),
sp.stats.norm.ppf(self.fnr),
**line_kwargs,
)
info_pos_label = (
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
)
xlabel = "False Positive Rate" + info_pos_label
ylabel = "False Negative Rate" + info_pos_label
self.ax_.set(xlabel=xlabel, ylabel=ylabel)
if "label" in line_kwargs:
self.ax_.legend(loc="lower right")
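# DET axes use a normal-deviate (probit) scale: tick positions are the
# Gaussian quantiles of the probabilities shown as tick labels.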
ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
tick_locations = sp.stats.norm.ppf(ticks)
tick_labels = [
"{:.0%}".format(s) if (100 * s).is_integer() else "{:.1%}".format(s)
for s in ticks
]
self.ax_.set_xticks(tick_locations)
self.ax_.set_xticklabels(tick_labels)
self.ax_.set_xlim(-3, 3)
self.ax_.set_yticks(tick_locations)
self.ax_.set_yticklabels(tick_labels)
self.ax_.set_ylim(-3, 3)
return self


@@ -0,0 +1,555 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from collections import Counter
from ...utils._plotting import (
_BinaryClassifierCurveDisplayMixin,
_despine,
_validate_style_kwargs,
)
from .._ranking import average_precision_score, precision_recall_curve
class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin):
"""Precision Recall visualization.
It is recommended to use
:func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` or
:func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` to create
a :class:`~sklearn.metrics.PrecisionRecallDisplay`. All parameters are
stored as attributes.
For general information regarding `scikit-learn` visualization tools, see
the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the :ref:`Model
Evaluation Guide <precision_recall_f_measure_metrics>`.
Parameters
----------
precision : ndarray
Precision values.
recall : ndarray
Recall values.
average_precision : float, default=None
Average precision. If None, the average precision is not shown.
estimator_name : str, default=None
Name of estimator. If None, then the estimator name is not shown.
pos_label : int, float, bool or str, default=None
The class considered as the positive class. If None, the class will not
be shown in the legend.
.. versionadded:: 0.24
prevalence_pos_label : float, default=None
The prevalence of the positive label. It is used for plotting the
chance level line. If None, the chance level line will not be plotted
even if `plot_chance_level` is set to True when plotting.
.. versionadded:: 1.3
Attributes
----------
line_ : matplotlib Artist
Precision recall curve.
chance_level_ : matplotlib Artist or None
The chance level line. It is `None` if the chance level is not plotted.
.. versionadded:: 1.3
ax_ : matplotlib Axes
Axes with precision recall curve.
figure_ : matplotlib Figure
Figure containing the curve.
See Also
--------
precision_recall_curve : Compute precision-recall pairs for different
probability thresholds.
PrecisionRecallDisplay.from_estimator : Plot Precision Recall Curve given
a binary classifier.
PrecisionRecallDisplay.from_predictions : Plot Precision Recall Curve
using predictions from a binary classifier.
Notes
-----
The average precision (cf. :func:`~sklearn.metrics.average_precision_score`) in
scikit-learn is computed without any interpolation. To be consistent with
this metric, the precision-recall curve is plotted without any
interpolation as well (step-wise style).
You can change this style by passing the keyword argument
`drawstyle="default"` in :meth:`plot`, :meth:`from_estimator`, or
:meth:`from_predictions`. However, the curve will not be strictly
consistent with the reported average precision.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import (precision_recall_curve,
... PrecisionRecallDisplay)
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
... random_state=0)
>>> clf = SVC(random_state=0)
>>> clf.fit(X_train, y_train)
SVC(random_state=0)
>>> predictions = clf.predict(X_test)
>>> precision, recall, _ = precision_recall_curve(y_test, predictions)
>>> disp = PrecisionRecallDisplay(precision=precision, recall=recall)
>>> disp.plot()
<...>
>>> plt.show()
"""
def __init__(
self,
precision,
recall,
*,
average_precision=None,
estimator_name=None,
pos_label=None,
prevalence_pos_label=None,
):
self.estimator_name = estimator_name
self.precision = precision
self.recall = recall
self.average_precision = average_precision
self.pos_label = pos_label
self.prevalence_pos_label = prevalence_pos_label
def plot(
self,
ax=None,
*,
name=None,
plot_chance_level=False,
chance_level_kw=None,
despine=False,
**kwargs,
):
"""Plot visualization.
Extra keyword arguments will be passed to matplotlib's `plot`.
Parameters
----------
ax : Matplotlib Axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
name : str, default=None
Name of precision recall curve for labeling. If `None`, use
`estimator_name` if not `None`, otherwise no labeling is shown.
plot_chance_level : bool, default=False
Whether to plot the chance level. The chance level is the prevalence
of the positive label computed from the data passed during
:meth:`from_estimator` or :meth:`from_predictions` call.
.. versionadded:: 1.3
chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.
.. versionadded:: 1.3
despine : bool, default=False
Whether to remove the top and right spines from the plot.
.. versionadded:: 1.6
**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.
Returns
-------
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
Object that stores computed values.
Notes
-----
The average precision (cf. :func:`~sklearn.metrics.average_precision_score`)
in scikit-learn is computed without any interpolation. To be consistent
with this metric, the precision-recall curve is plotted without any
interpolation as well (step-wise style).
You can change this style by passing the keyword argument
`drawstyle="default"`. However, the curve will not be strictly
consistent with the reported average precision.
"""
self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
default_line_kwargs = {"drawstyle": "steps-post"}
if self.average_precision is not None and name is not None:
default_line_kwargs["label"] = (
f"{name} (AP = {self.average_precision:0.2f})"
)
elif self.average_precision is not None:
default_line_kwargs["label"] = f"AP = {self.average_precision:0.2f}"
elif name is not None:
default_line_kwargs["label"] = name
line_kwargs = _validate_style_kwargs(default_line_kwargs, kwargs)
(self.line_,) = self.ax_.plot(self.recall, self.precision, **line_kwargs)
info_pos_label = (
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
)
xlabel = "Recall" + info_pos_label
ylabel = "Precision" + info_pos_label
self.ax_.set(
xlabel=xlabel,
xlim=(-0.01, 1.01),
ylabel=ylabel,
ylim=(-0.01, 1.01),
aspect="equal",
)
if plot_chance_level:
if self.prevalence_pos_label is None:
raise ValueError(
"You must provide prevalence_pos_label when constructing the "
"PrecisionRecallDisplay object in order to plot the chance "
"level line. Alternatively, you may use "
"PrecisionRecallDisplay.from_estimator or "
"PrecisionRecallDisplay.from_predictions "
"to automatically set prevalence_pos_label"
)
default_chance_level_line_kw = {
"label": f"Chance level (AP = {self.prevalence_pos_label:0.2f})",
"color": "k",
"linestyle": "--",
}
if chance_level_kw is None:
chance_level_kw = {}
chance_level_line_kw = _validate_style_kwargs(
default_chance_level_line_kw, chance_level_kw
)
(self.chance_level_,) = self.ax_.plot(
(0, 1),
(self.prevalence_pos_label, self.prevalence_pos_label),
**chance_level_line_kw,
)
else:
self.chance_level_ = None
if despine:
_despine(self.ax_)
if "label" in line_kwargs or plot_chance_level:
self.ax_.legend(loc="lower left")
return self
@classmethod
def from_estimator(
cls,
estimator,
X,
y,
*,
sample_weight=None,
drop_intermediate=False,
response_method="auto",
pos_label=None,
name=None,
ax=None,
plot_chance_level=False,
chance_level_kw=None,
despine=False,
**kwargs,
):
"""Plot precision-recall curve given an estimator and some data.
For general information regarding `scikit-learn` visualization tools, see
the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the :ref:`Model
Evaluation Guide <precision_recall_f_measure_metrics>`.
Parameters
----------
estimator : estimator instance
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
in which the last estimator is a classifier.
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.
y : array-like of shape (n_samples,)
Target values.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
drop_intermediate : bool, default=False
Whether to drop some suboptimal thresholds which would not appear
on a plotted precision-recall curve. This is useful in order to
create lighter precision-recall curves.
.. versionadded:: 1.3
response_method : {'predict_proba', 'decision_function', 'auto'}, \
default='auto'
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the target response. If set to 'auto',
:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next.
pos_label : int, float, bool or str, default=None
The class considered as the positive class when computing the
precision and recall metrics. By default, `estimator.classes_[1]`
is considered as the positive class.
name : str, default=None
Name for labeling curve. If `None`, no name is used.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.
plot_chance_level : bool, default=False
Whether to plot the chance level. The chance level is the prevalence
of the positive label computed from the data passed during
:meth:`from_estimator` or :meth:`from_predictions` call.
.. versionadded:: 1.3
chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.
.. versionadded:: 1.3
despine : bool, default=False
Whether to remove the top and right spines from the plot.
.. versionadded:: 1.6
**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.
Returns
-------
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
See Also
--------
PrecisionRecallDisplay.from_predictions : Plot precision-recall curve
using estimated probabilities or output of decision function.
Notes
-----
The average precision (cf. :func:`~sklearn.metrics.average_precision_score`)
in scikit-learn is computed without any interpolation. To be consistent
with this metric, the precision-recall curve is plotted without any
interpolation as well (step-wise style).
You can change this style by passing the keyword argument
`drawstyle="default"`. However, the curve will not be strictly
consistent with the reported average precision.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import PrecisionRecallDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.linear_model import LogisticRegression
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, random_state=0)
>>> clf = LogisticRegression()
>>> clf.fit(X_train, y_train)
LogisticRegression()
>>> PrecisionRecallDisplay.from_estimator(
... clf, X_test, y_test)
<...>
>>> plt.show()
"""
y_pred, pos_label, name = cls._validate_and_get_response_values(
estimator,
X,
y,
response_method=response_method,
pos_label=pos_label,
name=name,
)
return cls.from_predictions(
y,
y_pred,
sample_weight=sample_weight,
name=name,
pos_label=pos_label,
drop_intermediate=drop_intermediate,
ax=ax,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
despine=despine,
**kwargs,
)
@classmethod
def from_predictions(
cls,
y_true,
y_pred,
*,
sample_weight=None,
drop_intermediate=False,
pos_label=None,
name=None,
ax=None,
plot_chance_level=False,
chance_level_kw=None,
despine=False,
**kwargs,
):
"""Plot precision-recall curve given binary class predictions.
For general information regarding `scikit-learn` visualization tools, see
the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the :ref:`Model
Evaluation Guide <precision_recall_f_measure_metrics>`.
Parameters
----------
y_true : array-like of shape (n_samples,)
True binary labels.
y_pred : array-like of shape (n_samples,)
Estimated probabilities or output of decision function.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
drop_intermediate : bool, default=False
Whether to drop some suboptimal thresholds which would not appear
on a plotted precision-recall curve. This is useful in order to
create lighter precision-recall curves.
.. versionadded:: 1.3
pos_label : int, float, bool or str, default=None
The class considered as the positive class when computing the
precision and recall metrics.
name : str, default=None
Name for labeling curve. If `None`, name will be set to
`"Classifier"`.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.
plot_chance_level : bool, default=False
Whether to plot the chance level. The chance level is the prevalence
of the positive label computed from the data passed during
:meth:`from_estimator` or :meth:`from_predictions` call.
.. versionadded:: 1.3
chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.
.. versionadded:: 1.3
despine : bool, default=False
Whether to remove the top and right spines from the plot.
.. versionadded:: 1.6
**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.
Returns
-------
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
See Also
--------
PrecisionRecallDisplay.from_estimator : Plot precision-recall curve
using an estimator.
Notes
-----
The average precision (cf. :func:`~sklearn.metrics.average_precision_score`)
in scikit-learn is computed without any interpolation. To be consistent
with this metric, the precision-recall curve is plotted without any
interpolation as well (step-wise style).
You can change this style by passing the keyword argument
`drawstyle="default"`. However, the curve will not be strictly
consistent with the reported average precision.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import PrecisionRecallDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.linear_model import LogisticRegression
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, random_state=0)
>>> clf = LogisticRegression()
>>> clf.fit(X_train, y_train)
LogisticRegression()
>>> y_pred = clf.predict_proba(X_test)[:, 1]
>>> PrecisionRecallDisplay.from_predictions(
... y_test, y_pred)
<...>
>>> plt.show()
"""
pos_label, name = cls._validate_from_predictions_params(
y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
)
precision, recall, _ = precision_recall_curve(
y_true,
y_pred,
pos_label=pos_label,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
)
average_precision = average_precision_score(
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
)
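# Prevalence of the positive class; used to draw the chance-level line
# (the AP of a no-skill classifier equals this prevalence).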
class_count = Counter(y_true)
prevalence_pos_label = class_count[pos_label] / sum(class_count.values())
viz = cls(
precision=precision,
recall=recall,
average_precision=average_precision,
estimator_name=name,
pos_label=pos_label,
prevalence_pos_label=prevalence_pos_label,
)
return viz.plot(
ax=ax,
name=name,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
despine=despine,
**kwargs,
)


@@ -0,0 +1,413 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import numbers
import numpy as np
from ...utils import _safe_indexing, check_random_state
from ...utils._optional_dependencies import check_matplotlib_support
from ...utils._plotting import _validate_style_kwargs
class PredictionErrorDisplay:
"""Visualization of the prediction error of a regression model.
This tool can display "residuals vs predicted" or "actual vs predicted"
using scatter plots to qualitatively assess the behavior of a regressor,
preferably on held-out data points.
See the details in the docstrings of
:func:`~sklearn.metrics.PredictionErrorDisplay.from_estimator` or
:func:`~sklearn.metrics.PredictionErrorDisplay.from_predictions` to
create a visualizer. All parameters are stored as attributes.
For general information regarding `scikit-learn` visualization tools, read
more in the :ref:`Visualization Guide <visualizations>`.
For details regarding interpreting these plots, refer to the
:ref:`Model Evaluation Guide <visualization_regression_evaluation>`.
.. versionadded:: 1.2
Parameters
----------
y_true : ndarray of shape (n_samples,)
True values.
y_pred : ndarray of shape (n_samples,)
Prediction values.
Attributes
----------
line_ : matplotlib Artist
Optimal line representing `y_true == y_pred`. It is a diagonal line
for `kind="actual_vs_predicted"` and a horizontal line for
`kind="residual_vs_predicted"`.
errors_lines_ : matplotlib Artist or None
Residual lines. If `with_errors=False`, then it is set to `None`.
scatter_ : matplotlib Artist
Scatter data points.
ax_ : matplotlib Axes
Axes containing the prediction error plot.
figure_ : matplotlib Figure
Figure containing the scatter and lines.
See Also
--------
PredictionErrorDisplay.from_estimator : Prediction error visualization
given an estimator and some data.
PredictionErrorDisplay.from_predictions : Prediction error visualization
given the true and predicted targets.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import load_diabetes
>>> from sklearn.linear_model import Ridge
>>> from sklearn.metrics import PredictionErrorDisplay
>>> X, y = load_diabetes(return_X_y=True)
>>> ridge = Ridge().fit(X, y)
>>> y_pred = ridge.predict(X)
>>> display = PredictionErrorDisplay(y_true=y, y_pred=y_pred)
>>> display.plot()
<...>
>>> plt.show()
"""
def __init__(self, *, y_true, y_pred):
self.y_true = y_true
self.y_pred = y_pred
def plot(
self,
ax=None,
*,
kind="residual_vs_predicted",
scatter_kwargs=None,
line_kwargs=None,
):
"""Plot visualization.
Extra keyword arguments will be passed to matplotlib's ``plot``.
Parameters
----------
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
default="residual_vs_predicted"
The type of plot to draw:
- "actual_vs_predicted" draws the observed values (y-axis) vs.
the predicted values (x-axis).
- "residual_vs_predicted" draws the residuals, i.e. difference
between observed and predicted values, (y-axis) vs. the predicted
values (x-axis).
scatter_kwargs : dict, default=None
Dictionary with keywords passed to the `matplotlib.pyplot.scatter`
call.
line_kwargs : dict, default=None
Dictionary with keyword passed to the `matplotlib.pyplot.plot`
call to draw the optimal line.
Returns
-------
display : :class:`~sklearn.metrics.PredictionErrorDisplay`
Object that stores computed values.
"""
check_matplotlib_support(f"{self.__class__.__name__}.plot")
expected_kind = ("actual_vs_predicted", "residual_vs_predicted")
if kind not in expected_kind:
raise ValueError(
f"`kind` must be one of {', '.join(expected_kind)}. "
f"Got {kind!r} instead."
)
import matplotlib.pyplot as plt
if scatter_kwargs is None:
scatter_kwargs = {}
if line_kwargs is None:
line_kwargs = {}
default_scatter_kwargs = {"color": "tab:blue", "alpha": 0.8}
default_line_kwargs = {"color": "black", "alpha": 0.7, "linestyle": "--"}
scatter_kwargs = _validate_style_kwargs(default_scatter_kwargs, scatter_kwargs)
line_kwargs = _validate_style_kwargs(default_line_kwargs, line_kwargs)
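# User-supplied style keywords take precedence over the defaults when
# the dictionaries are merged below.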
scatter_kwargs = {**default_scatter_kwargs, **scatter_kwargs}
line_kwargs = {**default_line_kwargs, **line_kwargs}
if ax is None:
_, ax = plt.subplots()
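# "actual_vs_predicted": scatter actual vs. predicted values around the
# y = x identity line; otherwise plot residuals around a horizontal
# zero-residual line.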
if kind == "actual_vs_predicted":
max_value = max(np.max(self.y_true), np.max(self.y_pred))
min_value = min(np.min(self.y_true), np.min(self.y_pred))
self.line_ = ax.plot(
[min_value, max_value], [min_value, max_value], **line_kwargs
)[0]
x_data, y_data = self.y_pred, self.y_true
xlabel, ylabel = "Predicted values", "Actual values"
self.scatter_ = ax.scatter(x_data, y_data, **scatter_kwargs)
# force to have a squared axis
ax.set_aspect("equal", adjustable="datalim")
ax.set_xticks(np.linspace(min_value, max_value, num=5))
ax.set_yticks(np.linspace(min_value, max_value, num=5))
else: # kind == "residual_vs_predicted"
self.line_ = ax.plot(
[np.min(self.y_pred), np.max(self.y_pred)],
[0, 0],
**line_kwargs,
)[0]
self.scatter_ = ax.scatter(
self.y_pred, self.y_true - self.y_pred, **scatter_kwargs
)
xlabel, ylabel = "Predicted values", "Residuals (actual - predicted)"
ax.set(xlabel=xlabel, ylabel=ylabel)
self.ax_ = ax
self.figure_ = ax.figure
return self
@classmethod
def from_estimator(
cls,
estimator,
X,
y,
*,
kind="residual_vs_predicted",
subsample=1_000,
random_state=None,
ax=None,
scatter_kwargs=None,
line_kwargs=None,
):
"""Plot the prediction error given a regressor and some data.
For general information regarding `scikit-learn` visualization tools,
read more in the :ref:`Visualization Guide <visualizations>`.
For details regarding interpreting these plots, refer to the
:ref:`Model Evaluation Guide <visualization_regression_evaluation>`.
.. versionadded:: 1.2
Parameters
----------
estimator : estimator instance
Fitted regressor or a fitted :class:`~sklearn.pipeline.Pipeline`
in which the last estimator is a regressor.
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.
y : array-like of shape (n_samples,)
Target values.
kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
default="residual_vs_predicted"
The type of plot to draw:
- "actual_vs_predicted" draws the observed values (y-axis) vs.
the predicted values (x-axis).
- "residual_vs_predicted" draws the residuals, i.e. difference
between observed and predicted values, (y-axis) vs. the predicted
values (x-axis).
subsample : float, int or None, default=1_000
Subsampling of the samples shown on the scatter plot. If `float`,
it should be between 0 and 1 and represents the proportion of the
original dataset. If `int`, it represents the number of samples
displayed on the scatter plot. If `None`, no subsampling is
applied. By default, at most 1,000 samples are displayed.
random_state : int or RandomState, default=None
Controls the randomness when `subsample` is not `None`.
See :term:`Glossary <random_state>` for details.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
scatter_kwargs : dict, default=None
Dictionary with keywords passed to the `matplotlib.pyplot.scatter`
call.
line_kwargs : dict, default=None
Dictionary with keyword passed to the `matplotlib.pyplot.plot`
call to draw the optimal line.
Returns
-------
display : :class:`~sklearn.metrics.PredictionErrorDisplay`
Object that stores the computed values.
See Also
--------
PredictionErrorDisplay : Prediction error visualization for regression.
PredictionErrorDisplay.from_predictions : Prediction error visualization
given the true and predicted targets.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import load_diabetes
>>> from sklearn.linear_model import Ridge
>>> from sklearn.metrics import PredictionErrorDisplay
>>> X, y = load_diabetes(return_X_y=True)
>>> ridge = Ridge().fit(X, y)
>>> disp = PredictionErrorDisplay.from_estimator(ridge, X, y)
>>> plt.show()
"""
check_matplotlib_support(f"{cls.__name__}.from_estimator")
y_pred = estimator.predict(X)
return cls.from_predictions(
y_true=y,
y_pred=y_pred,
kind=kind,
subsample=subsample,
random_state=random_state,
ax=ax,
scatter_kwargs=scatter_kwargs,
line_kwargs=line_kwargs,
)
@classmethod
def from_predictions(
cls,
y_true,
y_pred,
*,
kind="residual_vs_predicted",
subsample=1_000,
random_state=None,
ax=None,
scatter_kwargs=None,
line_kwargs=None,
):
"""Plot the prediction error given the true and predicted targets.
For general information regarding `scikit-learn` visualization tools,
read more in the :ref:`Visualization Guide <visualizations>`.
For details regarding interpreting these plots, refer to the
:ref:`Model Evaluation Guide <visualization_regression_evaluation>`.
.. versionadded:: 1.2
Parameters
----------
y_true : array-like of shape (n_samples,)
True target values.
y_pred : array-like of shape (n_samples,)
Predicted target values.
kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
default="residual_vs_predicted"
The type of plot to draw:
- "actual_vs_predicted" draws the observed values (y-axis) vs.
the predicted values (x-axis).
- "residual_vs_predicted" draws the residuals, i.e. difference
between observed and predicted values, (y-axis) vs. the predicted
values (x-axis).
subsample : float, int or None, default=1_000
Subsampling of the samples shown on the scatter plot. If `float`,
it should be between 0 and 1 and represents the proportion of the
original dataset. If `int`, it represents the number of samples
displayed on the scatter plot. If `None`, no subsampling is
applied. By default, at most 1,000 samples are displayed.
random_state : int or RandomState, default=None
Controls the randomness when `subsample` is not `None`.
See :term:`Glossary <random_state>` for details.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
scatter_kwargs : dict, default=None
Dictionary with keywords passed to the `matplotlib.pyplot.scatter`
call.
line_kwargs : dict, default=None
Dictionary with keyword passed to the `matplotlib.pyplot.plot`
call to draw the optimal line.
Returns
-------
display : :class:`~sklearn.metrics.PredictionErrorDisplay`
Object that stores the computed values.
See Also
--------
PredictionErrorDisplay : Prediction error visualization for regression.
PredictionErrorDisplay.from_estimator : Prediction error visualization
given an estimator and some data.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import load_diabetes
>>> from sklearn.linear_model import Ridge
>>> from sklearn.metrics import PredictionErrorDisplay
>>> X, y = load_diabetes(return_X_y=True)
>>> ridge = Ridge().fit(X, y)
>>> y_pred = ridge.predict(X)
>>> disp = PredictionErrorDisplay.from_predictions(y_true=y, y_pred=y_pred)
>>> plt.show()
"""
check_matplotlib_support(f"{cls.__name__}.from_predictions")
random_state = check_random_state(random_state)
n_samples = len(y_true)
if isinstance(subsample, numbers.Integral):
if subsample <= 0:
raise ValueError(
f"When an integer, subsample={subsample} should be positive."
)
elif isinstance(subsample, numbers.Real):
if subsample <= 0 or subsample >= 1:
raise ValueError(
f"When a floating-point, subsample={subsample} should"
" be in the (0, 1) range."
)
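# A fractional subsample is converted into an absolute number of samples.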
subsample = int(n_samples * subsample)
if subsample is not None and subsample < n_samples:
indices = random_state.choice(np.arange(n_samples), size=subsample)
y_true = _safe_indexing(y_true, indices, axis=0)
y_pred = _safe_indexing(y_pred, indices, axis=0)
viz = cls(
y_true=y_true,
y_pred=y_pred,
)
return viz.plot(
ax=ax,
kind=kind,
scatter_kwargs=scatter_kwargs,
line_kwargs=line_kwargs,
)


@@ -0,0 +1,795 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import warnings
import numpy as np
from ...utils import _safe_indexing
from ...utils._plotting import (
_BinaryClassifierCurveDisplayMixin,
_check_param_lengths,
_convert_to_list_leaving_none,
_deprecate_estimator_name,
_despine,
_validate_style_kwargs,
)
from ...utils._response import _get_response_values_binary
from .._ranking import auc, roc_curve
class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin):
"""ROC Curve visualization.
It is recommended to use
:func:`~sklearn.metrics.RocCurveDisplay.from_estimator` or
:func:`~sklearn.metrics.RocCurveDisplay.from_predictions` or
:func:`~sklearn.metrics.RocCurveDisplay.from_cv_results` to create
a :class:`~sklearn.metrics.RocCurveDisplay`. All parameters are
stored as attributes.
For general information regarding `scikit-learn` visualization tools, see
the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the :ref:`Model
Evaluation Guide <roc_metrics>`.
Parameters
----------
fpr : ndarray or list of ndarrays
False positive rates. Each ndarray should contain values for a single curve.
If plotting multiple curves, list should be of same length as `tpr`.
.. versionchanged:: 1.7
Now accepts a list for plotting multiple curves.
tpr : ndarray or list of ndarrays
True positive rates. Each ndarray should contain values for a single curve.
If plotting multiple curves, list should be of same length as `fpr`.
.. versionchanged:: 1.7
Now accepts a list for plotting multiple curves.
roc_auc : float or list of floats, default=None
Area under ROC curve, used for labeling each curve in the legend.
If plotting multiple curves, should be a list of the same length as `fpr`
and `tpr`. If `None`, ROC AUC scores are not shown in the legend.
.. versionchanged:: 1.7
Now accepts a list for plotting multiple curves.
name : str or list of str, default=None
Name for labeling legend entries. The number of legend entries is determined
by the `curve_kwargs` passed to `plot`, and is not affected by `name`.
To label each curve, provide a list of strings. To avoid labeling
individual curves that have the same appearance, this cannot be used in
conjunction with `curve_kwargs` being a dictionary or None. If a
string is provided, it will be used to either label the single legend entry
or if there are multiple legend entries, label each individual curve with
the same name. If still `None`, no name is shown in the legend.
.. versionadded:: 1.7
pos_label : int, float, bool or str, default=None
The class considered as the positive class when computing the roc auc
metrics. By default, `estimator.classes_[1]` is considered
as the positive class.
.. versionadded:: 0.24
estimator_name : str, default=None
Name of estimator. If None, the estimator name is not shown.
.. deprecated:: 1.7
`estimator_name` is deprecated and will be removed in 1.9. Use `name`
instead.
Attributes
----------
line_ : matplotlib Artist or list of matplotlib Artists
ROC Curves.
.. versionchanged:: 1.7
This attribute can now be a list of Artists, for when multiple curves
are plotted.
chance_level_ : matplotlib Artist or None
The chance level line. It is `None` if the chance level is not plotted.
.. versionadded:: 1.3
ax_ : matplotlib Axes
Axes with ROC Curve.
figure_ : matplotlib Figure
Figure containing the curve.
See Also
--------
roc_curve : Compute Receiver operating characteristic (ROC) curve.
RocCurveDisplay.from_estimator : Plot Receiver Operating Characteristic
(ROC) curve given an estimator and some data.
RocCurveDisplay.from_predictions : Plot Receiver Operating Characteristic
(ROC) curve given the true and predicted values.
roc_auc_score : Compute the area under the ROC curve.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> from sklearn import metrics
>>> y_true = np.array([0, 0, 1, 1])
>>> y_score = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score)
>>> roc_auc = metrics.auc(fpr, tpr)
>>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
... name='example estimator')
>>> display.plot()
<...>
>>> plt.show()
"""
def __init__(
self,
*,
fpr,
tpr,
roc_auc=None,
name=None,
pos_label=None,
estimator_name="deprecated",
):
self.fpr = fpr
self.tpr = tpr
self.roc_auc = roc_auc
self.name = _deprecate_estimator_name(estimator_name, name, "1.7")
self.pos_label = pos_label
def _validate_plot_params(self, *, ax, name):
self.ax_, self.figure_, name = super()._validate_plot_params(ax=ax, name=name)
fpr = _convert_to_list_leaving_none(self.fpr)
tpr = _convert_to_list_leaving_none(self.tpr)
roc_auc = _convert_to_list_leaving_none(self.roc_auc)
name = _convert_to_list_leaving_none(name)
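# fpr and tpr must always have matching lengths; roc_auc and, when given
# as a list, name are checked against them as well.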
optional = {"self.roc_auc": roc_auc}
if isinstance(name, list) and len(name) != 1:
optional.update({"'name' (or self.name)": name})
_check_param_lengths(
required={"self.fpr": fpr, "self.tpr": tpr},
optional=optional,
class_name="RocCurveDisplay",
)
return fpr, tpr, roc_auc, name
def plot(
self,
ax=None,
*,
name=None,
curve_kwargs=None,
plot_chance_level=False,
chance_level_kw=None,
despine=False,
**kwargs,
):
"""Plot visualization.
Parameters
----------
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
name : str or list of str, default=None
Name for labeling legend entries. The number of legend entries
is determined by `curve_kwargs`, and is not affected by `name`.
To label each curve, provide a list of strings. To avoid labeling
individual curves that have the same appearance, this cannot be used in
conjunction with `curve_kwargs` being a dictionary or None. If a
string is provided, it will be used to either label the single legend entry
or if there are multiple legend entries, label each individual curve with
the same name. If `None`, set to `name` provided at `RocCurveDisplay`
initialization. If still `None`, no name is shown in the legend.
.. versionadded:: 1.7
curve_kwargs : dict or list of dict, default=None
            Keyword arguments to be passed to matplotlib's `plot` function
            to draw individual ROC curves. For single-curve plotting, this
            should be a dictionary. For multi-curve plotting, if a list is
            provided, the parameters are applied to each ROC curve
            sequentially and a legend entry is added for each curve.
If a single dictionary is provided, the same parameters are applied
to all ROC curves and a single legend entry for all curves is added,
labeled with the mean ROC AUC score.
.. versionadded:: 1.7
plot_chance_level : bool, default=False
Whether to plot the chance level.
.. versionadded:: 1.3
chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.
.. versionadded:: 1.3
despine : bool, default=False
Whether to remove the top and right spines from the plot.
.. versionadded:: 1.6
**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.
.. deprecated:: 1.7
kwargs is deprecated and will be removed in 1.9. Pass matplotlib
arguments to `curve_kwargs` as a dictionary instead.
Returns
-------
display : :class:`~sklearn.metrics.RocCurveDisplay`
Object that stores computed values.
"""
fpr, tpr, roc_auc, name = self._validate_plot_params(ax=ax, name=name)
n_curves = len(fpr)
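        # When multiple curves are drawn with a single style (`curve_kwargs`
        # is a dict or None), summarize the legend with the mean and std of
        # the AUC values; otherwise keep one AUC value per curve so each
        # curve gets its own legend entry.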
if not isinstance(curve_kwargs, list) and n_curves > 1:
if roc_auc:
legend_metric = {"mean": np.mean(roc_auc), "std": np.std(roc_auc)}
else:
legend_metric = {"mean": None, "std": None}
else:
roc_auc = roc_auc if roc_auc is not None else [None] * n_curves
legend_metric = {"metric": roc_auc}
curve_kwargs = self._validate_curve_kwargs(
n_curves,
name,
legend_metric,
"AUC",
curve_kwargs=curve_kwargs,
**kwargs,
)
default_chance_level_line_kw = {
"label": "Chance level (AUC = 0.5)",
"color": "k",
"linestyle": "--",
}
if chance_level_kw is None:
chance_level_kw = {}
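        # Merge the default chance level style with user overrides; matplotlib
        # keyword aliases (e.g. "c" for "color") are resolved during merging.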
chance_level_kw = _validate_style_kwargs(
default_chance_level_line_kw, chance_level_kw
)
self.line_ = []
for fpr, tpr, line_kw in zip(fpr, tpr, curve_kwargs):
self.line_.extend(self.ax_.plot(fpr, tpr, **line_kw))
# Return single artist if only one curve is plotted
if len(self.line_) == 1:
self.line_ = self.line_[0]
info_pos_label = (
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
)
xlabel = "False Positive Rate" + info_pos_label
ylabel = "True Positive Rate" + info_pos_label
self.ax_.set(
xlabel=xlabel,
xlim=(-0.01, 1.01),
ylabel=ylabel,
ylim=(-0.01, 1.01),
aspect="equal",
)
if plot_chance_level:
(self.chance_level_,) = self.ax_.plot((0, 1), (0, 1), **chance_level_kw)
else:
self.chance_level_ = None
if despine:
_despine(self.ax_)
if curve_kwargs[0].get("label") is not None or (
plot_chance_level and chance_level_kw.get("label") is not None
):
self.ax_.legend(loc="lower right")
return self
@classmethod
def from_estimator(
cls,
estimator,
X,
y,
*,
sample_weight=None,
drop_intermediate=True,
response_method="auto",
pos_label=None,
name=None,
ax=None,
curve_kwargs=None,
plot_chance_level=False,
chance_level_kw=None,
despine=False,
**kwargs,
):
"""Create a ROC Curve display from an estimator.
For general information regarding `scikit-learn` visualization tools,
see the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the :ref:`Model
Evaluation Guide <roc_metrics>`.
Parameters
----------
estimator : estimator instance
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
in which the last estimator is a classifier.
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.
y : array-like of shape (n_samples,)
Target values.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
drop_intermediate : bool, default=True
Whether to drop thresholds where the resulting point is collinear
with its neighbors in ROC space. This has no effect on the ROC AUC
or visual shape of the curve, but reduces the number of plotted
points.
response_method : {'predict_proba', 'decision_function', 'auto'} \
default='auto'
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the target response. If set to 'auto',
:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next.
pos_label : int, float, bool or str, default=None
The class considered as the positive class when computing the ROC AUC.
            By default, `estimator.classes_[1]` is considered as the positive
            class.
name : str, default=None
Name of ROC Curve for labeling. If `None`, use the name of the
estimator.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.
curve_kwargs : dict, default=None
            Keyword arguments to be passed to matplotlib's `plot` function.
.. versionadded:: 1.7
plot_chance_level : bool, default=False
Whether to plot the chance level.
.. versionadded:: 1.3
chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.
.. versionadded:: 1.3
despine : bool, default=False
Whether to remove the top and right spines from the plot.
.. versionadded:: 1.6
**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.
.. deprecated:: 1.7
kwargs is deprecated and will be removed in 1.9. Pass matplotlib
arguments to `curve_kwargs` as a dictionary instead.
Returns
-------
display : :class:`~sklearn.metrics.RocCurveDisplay`
The ROC Curve display.
See Also
--------
roc_curve : Compute Receiver operating characteristic (ROC) curve.
        RocCurveDisplay.from_predictions : Plot Receiver Operating Characteristic
            (ROC) curve given the true and predicted values.
roc_auc_score : Compute the area under the ROC curve.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import RocCurveDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, random_state=0)
>>> clf = SVC(random_state=0).fit(X_train, y_train)
>>> RocCurveDisplay.from_estimator(
... clf, X_test, y_test)
<...>
>>> plt.show()
"""
y_score, pos_label, name = cls._validate_and_get_response_values(
estimator,
X,
y,
response_method=response_method,
pos_label=pos_label,
name=name,
)
return cls.from_predictions(
y_true=y,
y_score=y_score,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
pos_label=pos_label,
name=name,
ax=ax,
curve_kwargs=curve_kwargs,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
despine=despine,
**kwargs,
)
@classmethod
def from_predictions(
cls,
y_true,
y_score=None,
*,
sample_weight=None,
drop_intermediate=True,
pos_label=None,
name=None,
ax=None,
curve_kwargs=None,
plot_chance_level=False,
chance_level_kw=None,
despine=False,
y_pred="deprecated",
**kwargs,
):
"""Plot ROC curve given the true and predicted values.
For general information regarding `scikit-learn` visualization tools,
see the :ref:`Visualization Guide <visualizations>`.
For guidance on interpreting these plots, refer to the :ref:`Model
Evaluation Guide <roc_metrics>`.
.. versionadded:: 1.0
Parameters
----------
y_true : array-like of shape (n_samples,)
True labels.
y_score : array-like of shape (n_samples,)
            Target scores, which can either be probability estimates of the
            positive class, confidence values, or a non-thresholded measure of
            decisions (as returned by "decision_function" on some classifiers).
.. versionadded:: 1.7
`y_pred` has been renamed to `y_score`.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
drop_intermediate : bool, default=True
Whether to drop thresholds where the resulting point is collinear
with its neighbors in ROC space. This has no effect on the ROC AUC
or visual shape of the curve, but reduces the number of plotted
points.
pos_label : int, float, bool or str, default=None
The label of the positive class when computing the ROC AUC.
When `pos_label=None`, if `y_true` is in {-1, 1} or {0, 1}, `pos_label`
is set to 1, otherwise an error will be raised.
name : str, default=None
Name of ROC curve for legend labeling. If `None`, name will be set to
`"Classifier"`.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
curve_kwargs : dict, default=None
            Keyword arguments to be passed to matplotlib's `plot` function.
.. versionadded:: 1.7
plot_chance_level : bool, default=False
Whether to plot the chance level.
.. versionadded:: 1.3
chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.
.. versionadded:: 1.3
despine : bool, default=False
Whether to remove the top and right spines from the plot.
.. versionadded:: 1.6
y_pred : array-like of shape (n_samples,)
            Target scores, which can either be probability estimates of the
            positive class, confidence values, or a non-thresholded measure of
            decisions (as returned by "decision_function" on some classifiers).
.. deprecated:: 1.7
`y_pred` is deprecated and will be removed in 1.9. Use
`y_score` instead.
**kwargs : dict
            Additional keyword arguments passed to matplotlib's `plot` function.
.. deprecated:: 1.7
kwargs is deprecated and will be removed in 1.9. Pass matplotlib
arguments to `curve_kwargs` as a dictionary instead.
Returns
-------
display : :class:`~sklearn.metrics.RocCurveDisplay`
Object that stores computed values.
See Also
--------
roc_curve : Compute Receiver operating characteristic (ROC) curve.
RocCurveDisplay.from_estimator : ROC Curve visualization given an
estimator and some data.
roc_auc_score : Compute the area under the ROC curve.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import RocCurveDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, random_state=0)
>>> clf = SVC(random_state=0).fit(X_train, y_train)
>>> y_score = clf.decision_function(X_test)
>>> RocCurveDisplay.from_predictions(y_test, y_score)
<...>
>>> plt.show()
"""
# TODO(1.9): remove after the end of the deprecation period of `y_pred`
if y_score is not None and not (
isinstance(y_pred, str) and y_pred == "deprecated"
):
raise ValueError(
"`y_pred` and `y_score` cannot be both specified. Please use `y_score`"
" only as `y_pred` is deprecated in 1.7 and will be removed in 1.9."
)
if not (isinstance(y_pred, str) and y_pred == "deprecated"):
warnings.warn(
(
"y_pred is deprecated in 1.7 and will be removed in 1.9. "
"Please use `y_score` instead."
),
FutureWarning,
)
y_score = y_pred
pos_label_validated, name = cls._validate_from_predictions_params(
y_true, y_score, sample_weight=sample_weight, pos_label=pos_label, name=name
)
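        # Compute the ROC curve and its AUC from the validated scores, then
        # wrap them in a display instance ready to be plotted.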
fpr, tpr, _ = roc_curve(
y_true,
y_score,
pos_label=pos_label,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
)
roc_auc = auc(fpr, tpr)
viz = cls(
fpr=fpr,
tpr=tpr,
roc_auc=roc_auc,
name=name,
pos_label=pos_label_validated,
)
return viz.plot(
ax=ax,
curve_kwargs=curve_kwargs,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
despine=despine,
**kwargs,
)
@classmethod
def from_cv_results(
cls,
cv_results,
X,
y,
*,
sample_weight=None,
drop_intermediate=True,
response_method="auto",
pos_label=None,
ax=None,
name=None,
curve_kwargs=None,
plot_chance_level=False,
chance_level_kwargs=None,
despine=False,
):
"""Create a multi-fold ROC curve display given cross-validation results.
.. versionadded:: 1.7
Parameters
----------
cv_results : dict
Dictionary as returned by :func:`~sklearn.model_selection.cross_validate`
using `return_estimator=True` and `return_indices=True` (i.e., dictionary
should contain the keys "estimator" and "indices").
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.
y : array-like of shape (n_samples,)
Target values.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
drop_intermediate : bool, default=True
Whether to drop some suboptimal thresholds which would not appear
on a plotted ROC curve. This is useful in order to create lighter
ROC curves.
response_method : {'predict_proba', 'decision_function', 'auto'} \
default='auto'
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the target response. If set to 'auto',
:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next.
pos_label : int, float, bool or str, default=None
            The class considered as the positive class when computing the ROC AUC.
            By default, `estimator.classes_[1]` is considered as the positive
            class.
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.
name : str or list of str, default=None
            Name for labeling legend entries. The number of legend entries
            is determined by `curve_kwargs` and is not affected by `name`.
            To label each curve individually, provide a list of strings; this
            cannot be combined with `curve_kwargs` being a dictionary or None,
            since those options draw all curves with the same appearance and a
            single legend entry. If a string is provided, it labels either the
            single legend entry or, when there are multiple legend entries,
            each individual curve with the same name. If `None`, no name is
            shown in the legend.
curve_kwargs : dict or list of dict, default=None
            Keyword arguments to be passed to matplotlib's `plot` function
to draw individual ROC curves. If a list is provided the
parameters are applied to the ROC curves of each CV fold
sequentially and a legend entry is added for each curve.
If a single dictionary is provided, the same parameters are applied
to all ROC curves and a single legend entry for all curves is added,
labeled with the mean ROC AUC score.
plot_chance_level : bool, default=False
Whether to plot the chance level.
chance_level_kwargs : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.
despine : bool, default=False
Whether to remove the top and right spines from the plot.
Returns
-------
display : :class:`~sklearn.metrics.RocCurveDisplay`
The multi-fold ROC curve display.
See Also
--------
roc_curve : Compute Receiver operating characteristic (ROC) curve.
RocCurveDisplay.from_estimator : ROC Curve visualization given an
estimator and some data.
        RocCurveDisplay.from_predictions : ROC Curve visualization given the
            true and predicted values.
roc_auc_score : Compute the area under the ROC curve.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import RocCurveDisplay
>>> from sklearn.model_selection import cross_validate
>>> from sklearn.svm import SVC
>>> X, y = make_classification(random_state=0)
>>> clf = SVC(random_state=0)
>>> cv_results = cross_validate(
... clf, X, y, cv=3, return_estimator=True, return_indices=True)
>>> RocCurveDisplay.from_cv_results(cv_results, X, y)
<...>
>>> plt.show()
"""
pos_label_ = cls._validate_from_cv_results_params(
cv_results,
X,
y,
sample_weight=sample_weight,
pos_label=pos_label,
)
fpr_folds, tpr_folds, auc_folds = [], [], []
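        # Score each fold's fitted estimator on its held-out test indices and
        # accumulate the per-fold ROC curves and AUC values.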
for estimator, test_indices in zip(
cv_results["estimator"], cv_results["indices"]["test"]
):
y_true = _safe_indexing(y, test_indices)
y_pred, _ = _get_response_values_binary(
estimator,
_safe_indexing(X, test_indices),
response_method=response_method,
pos_label=pos_label_,
)
sample_weight_fold = (
None
if sample_weight is None
else _safe_indexing(sample_weight, test_indices)
)
fpr, tpr, _ = roc_curve(
y_true,
y_pred,
pos_label=pos_label_,
sample_weight=sample_weight_fold,
drop_intermediate=drop_intermediate,
)
roc_auc = auc(fpr, tpr)
fpr_folds.append(fpr)
tpr_folds.append(tpr)
auc_folds.append(roc_auc)
viz = cls(
fpr=fpr_folds,
tpr=tpr_folds,
roc_auc=auc_folds,
name=name,
pos_label=pos_label_,
)
return viz.plot(
ax=ax,
curve_kwargs=curve_kwargs,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kwargs,
despine=despine,
)

View File

@@ -0,0 +1,292 @@
import numpy as np
import pytest
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.calibration import CalibrationDisplay
from sklearn.compose import make_column_transformer
from sklearn.datasets import load_iris
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
ConfusionMatrixDisplay,
DetCurveDisplay,
PrecisionRecallDisplay,
PredictionErrorDisplay,
RocCurveDisplay,
)
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
@pytest.fixture(scope="module")
def data():
return load_iris(return_X_y=True)
@pytest.fixture(scope="module")
def data_binary(data):
X, y = data
return X[y < 2], y[y < 2]
@pytest.mark.parametrize(
"Display",
[CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay],
)
def test_display_curve_error_classifier(pyplot, data, data_binary, Display):
"""Check that a proper error is raised when only binary classification is
supported."""
X, y = data
X_binary, y_binary = data_binary
clf = DecisionTreeClassifier().fit(X, y)
# Case 1: multiclass classifier with multiclass target
msg = "Expected 'estimator' to be a binary classifier. Got 3 classes instead."
with pytest.raises(ValueError, match=msg):
Display.from_estimator(clf, X, y)
# Case 2: multiclass classifier with binary target
with pytest.raises(ValueError, match=msg):
Display.from_estimator(clf, X_binary, y_binary)
# Case 3: binary classifier with multiclass target
clf = DecisionTreeClassifier().fit(X_binary, y_binary)
msg = "The target y is not binary. Got multiclass type of target."
with pytest.raises(ValueError, match=msg):
Display.from_estimator(clf, X, y)
@pytest.mark.parametrize(
"Display",
[CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay],
)
def test_display_curve_error_regression(pyplot, data_binary, Display):
"""Check that we raise an error with regressor."""
# Case 1: regressor
X, y = data_binary
regressor = DecisionTreeRegressor().fit(X, y)
msg = "Expected 'estimator' to be a binary classifier. Got DecisionTreeRegressor"
with pytest.raises(ValueError, match=msg):
Display.from_estimator(regressor, X, y)
# Case 2: regression target
classifier = DecisionTreeClassifier().fit(X, y)
# Force `y_true` to be seen as a regression problem
y = y + 0.5
msg = "The target y is not binary. Got continuous type of target."
with pytest.raises(ValueError, match=msg):
Display.from_estimator(classifier, X, y)
with pytest.raises(ValueError, match=msg):
Display.from_predictions(y, regressor.fit(X, y).predict(X))
@pytest.mark.parametrize(
"response_method, msg",
[
(
"predict_proba",
"MyClassifier has none of the following attributes: predict_proba.",
),
(
"decision_function",
"MyClassifier has none of the following attributes: decision_function.",
),
(
"auto",
(
"MyClassifier has none of the following attributes: predict_proba,"
" decision_function."
),
),
(
"bad_method",
"MyClassifier has none of the following attributes: bad_method.",
),
],
)
@pytest.mark.parametrize(
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_error_no_response(
pyplot,
data_binary,
response_method,
msg,
Display,
):
"""Check that a proper error is raised when the response method requested
is not defined for the given trained classifier."""
X, y = data_binary
class MyClassifier(ClassifierMixin, BaseEstimator):
def fit(self, X, y):
self.classes_ = [0, 1]
return self
clf = MyClassifier().fit(X, y)
with pytest.raises(AttributeError, match=msg):
Display.from_estimator(clf, X, y, response_method=response_method)
@pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay])
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_display_curve_estimator_name_multiple_calls(
pyplot,
data_binary,
Display,
constructor_name,
):
"""Check that passing `name` when calling `plot` will overwrite the original name
in the legend."""
X, y = data_binary
clf_name = "my hand-crafted name"
clf = LogisticRegression().fit(X, y)
y_pred = clf.predict_proba(X)[:, 1]
# safe guard for the binary if/else construction
assert constructor_name in ("from_estimator", "from_predictions")
if constructor_name == "from_estimator":
disp = Display.from_estimator(clf, X, y, name=clf_name)
else:
disp = Display.from_predictions(y, y_pred, name=clf_name)
assert disp.estimator_name == clf_name
pyplot.close("all")
disp.plot()
assert clf_name in disp.line_.get_label()
pyplot.close("all")
clf_name = "another_name"
disp.plot(name=clf_name)
assert clf_name in disp.line_.get_label()
# TODO: remove this test once all classes have moved to using `name` instead
# of `estimator_name`
@pytest.mark.parametrize(
"clf",
[
LogisticRegression(),
make_pipeline(StandardScaler(), LogisticRegression()),
make_pipeline(
make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
),
],
)
@pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay])
def test_display_curve_not_fitted_errors_old_name(pyplot, data_binary, clf, Display):
"""Check that a proper error is raised when the classifier is not
fitted."""
X, y = data_binary
# clone since we parametrize the test and the classifier will be fitted
# when testing the second and subsequent plotting function
model = clone(clf)
with pytest.raises(NotFittedError):
Display.from_estimator(model, X, y)
model.fit(X, y)
disp = Display.from_estimator(model, X, y)
assert model.__class__.__name__ in disp.line_.get_label()
assert disp.estimator_name == model.__class__.__name__
@pytest.mark.parametrize(
"clf",
[
LogisticRegression(),
make_pipeline(StandardScaler(), LogisticRegression()),
make_pipeline(
make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
),
],
)
@pytest.mark.parametrize("Display", [RocCurveDisplay])
def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display):
"""Check that a proper error is raised when the classifier is not fitted."""
X, y = data_binary
# clone since we parametrize the test and the classifier will be fitted
# when testing the second and subsequent plotting function
model = clone(clf)
with pytest.raises(NotFittedError):
Display.from_estimator(model, X, y)
model.fit(X, y)
disp = Display.from_estimator(model, X, y)
assert model.__class__.__name__ in disp.line_.get_label()
assert disp.name == model.__class__.__name__
@pytest.mark.parametrize(
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_n_samples_consistency(pyplot, data_binary, Display):
"""Check the error raised when `y_pred` or `sample_weight` have inconsistent
length."""
X, y = data_binary
classifier = DecisionTreeClassifier().fit(X, y)
msg = "Found input variables with inconsistent numbers of samples"
with pytest.raises(ValueError, match=msg):
Display.from_estimator(classifier, X[:-2], y)
with pytest.raises(ValueError, match=msg):
Display.from_estimator(classifier, X, y[:-2])
with pytest.raises(ValueError, match=msg):
Display.from_estimator(classifier, X, y, sample_weight=np.ones(X.shape[0] - 2))
@pytest.mark.parametrize(
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_error_pos_label(pyplot, data_binary, Display):
"""Check consistence of error message when `pos_label` should be specified."""
X, y = data_binary
y = y + 10
classifier = DecisionTreeClassifier().fit(X, y)
y_pred = classifier.predict_proba(X)[:, -1]
msg = r"y_true takes value in {10, 11} and pos_label is not specified"
with pytest.raises(ValueError, match=msg):
Display.from_predictions(y, y_pred)
@pytest.mark.parametrize(
"Display",
[
CalibrationDisplay,
DetCurveDisplay,
PrecisionRecallDisplay,
RocCurveDisplay,
PredictionErrorDisplay,
ConfusionMatrixDisplay,
],
)
@pytest.mark.parametrize(
"constructor",
["from_predictions", "from_estimator"],
)
def test_classifier_display_curve_named_constructor_return_type(
pyplot, data_binary, Display, constructor
):
"""Check that named constructors return the correct type when subclassed.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/pull/27675
"""
X, y = data_binary
# This can be anything - we just need to check the named constructor return
# type so the only requirement here is instantiating the class without error
y_pred = y
classifier = LogisticRegression().fit(X, y)
class SubclassOfDisplay(Display):
pass
if constructor == "from_predictions":
curve = SubclassOfDisplay.from_predictions(y, y_pred)
else: # constructor == "from_estimator"
curve = SubclassOfDisplay.from_estimator(classifier, X, y)
assert isinstance(curve, SubclassOfDisplay)

View File

@@ -0,0 +1,374 @@
import numpy as np
import pytest
from numpy.testing import (
assert_allclose,
assert_array_equal,
)
from sklearn.compose import make_column_transformer
from sklearn.datasets import make_classification
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC, SVR
def test_confusion_matrix_display_validation(pyplot):
"""Check that we raise the proper error when validating parameters."""
X, y = make_classification(
n_samples=100, n_informative=5, n_classes=5, random_state=0
)
with pytest.raises(NotFittedError):
ConfusionMatrixDisplay.from_estimator(SVC(), X, y)
regressor = SVR().fit(X, y)
y_pred_regressor = regressor.predict(X)
y_pred_classifier = SVC().fit(X, y).predict(X)
err_msg = "ConfusionMatrixDisplay.from_estimator only supports classifiers"
with pytest.raises(ValueError, match=err_msg):
ConfusionMatrixDisplay.from_estimator(regressor, X, y)
err_msg = "Mix type of y not allowed, got types"
with pytest.raises(ValueError, match=err_msg):
# Force `y_true` to be seen as a regression problem
ConfusionMatrixDisplay.from_predictions(y + 0.5, y_pred_classifier)
with pytest.raises(ValueError, match=err_msg):
ConfusionMatrixDisplay.from_predictions(y, y_pred_regressor)
err_msg = "Found input variables with inconsistent numbers of samples"
with pytest.raises(ValueError, match=err_msg):
ConfusionMatrixDisplay.from_predictions(y, y_pred_classifier[::2])
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("with_labels", [True, False])
@pytest.mark.parametrize("with_display_labels", [True, False])
def test_confusion_matrix_display_custom_labels(
pyplot, constructor_name, with_labels, with_display_labels
):
"""Check the resulting plot when labels are given."""
n_classes = 5
X, y = make_classification(
n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
)
classifier = SVC().fit(X, y)
y_pred = classifier.predict(X)
# safe guard for the binary if/else construction
assert constructor_name in ("from_estimator", "from_predictions")
ax = pyplot.gca()
labels = [2, 1, 0, 3, 4] if with_labels else None
display_labels = ["b", "d", "a", "e", "f"] if with_display_labels else None
cm = confusion_matrix(y, y_pred, labels=labels)
common_kwargs = {
"ax": ax,
"display_labels": display_labels,
"labels": labels,
}
if constructor_name == "from_estimator":
disp = ConfusionMatrixDisplay.from_estimator(classifier, X, y, **common_kwargs)
else:
disp = ConfusionMatrixDisplay.from_predictions(y, y_pred, **common_kwargs)
assert_allclose(disp.confusion_matrix, cm)
if with_display_labels:
expected_display_labels = display_labels
elif with_labels:
expected_display_labels = labels
else:
expected_display_labels = list(range(n_classes))
expected_display_labels_str = [str(name) for name in expected_display_labels]
x_ticks = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
y_ticks = [tick.get_text() for tick in disp.ax_.get_yticklabels()]
assert_array_equal(disp.display_labels, expected_display_labels)
assert_array_equal(x_ticks, expected_display_labels_str)
assert_array_equal(y_ticks, expected_display_labels_str)
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("normalize", ["true", "pred", "all", None])
@pytest.mark.parametrize("include_values", [True, False])
def test_confusion_matrix_display_plotting(
pyplot,
constructor_name,
normalize,
include_values,
):
"""Check the overall plotting rendering."""
n_classes = 5
X, y = make_classification(
n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
)
classifier = SVC().fit(X, y)
y_pred = classifier.predict(X)
# safe guard for the binary if/else construction
assert constructor_name in ("from_estimator", "from_predictions")
ax = pyplot.gca()
cmap = "plasma"
cm = confusion_matrix(y, y_pred)
common_kwargs = {
"normalize": normalize,
"cmap": cmap,
"ax": ax,
"include_values": include_values,
}
if constructor_name == "from_estimator":
disp = ConfusionMatrixDisplay.from_estimator(classifier, X, y, **common_kwargs)
else:
disp = ConfusionMatrixDisplay.from_predictions(y, y_pred, **common_kwargs)
assert disp.ax_ == ax
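    # Recompute the expected confusion matrix with the requested normalization
    # to compare against the values stored by the display.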
if normalize == "true":
cm = cm / cm.sum(axis=1, keepdims=True)
elif normalize == "pred":
cm = cm / cm.sum(axis=0, keepdims=True)
elif normalize == "all":
cm = cm / cm.sum()
assert_allclose(disp.confusion_matrix, cm)
import matplotlib as mpl
assert isinstance(disp.im_, mpl.image.AxesImage)
assert disp.im_.get_cmap().name == cmap
assert isinstance(disp.ax_, pyplot.Axes)
assert isinstance(disp.figure_, pyplot.Figure)
assert disp.ax_.get_ylabel() == "True label"
assert disp.ax_.get_xlabel() == "Predicted label"
x_ticks = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
y_ticks = [tick.get_text() for tick in disp.ax_.get_yticklabels()]
expected_display_labels = list(range(n_classes))
expected_display_labels_str = [str(name) for name in expected_display_labels]
assert_array_equal(disp.display_labels, expected_display_labels)
assert_array_equal(x_ticks, expected_display_labels_str)
assert_array_equal(y_ticks, expected_display_labels_str)
image_data = disp.im_.get_array().data
assert_allclose(image_data, cm)
if include_values:
assert disp.text_.shape == (n_classes, n_classes)
fmt = ".2g"
expected_text = np.array([format(v, fmt) for v in cm.ravel(order="C")])
text_text = np.array([t.get_text() for t in disp.text_.ravel(order="C")])
assert_array_equal(expected_text, text_text)
else:
assert disp.text_ is None
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_confusion_matrix_display(pyplot, constructor_name):
"""Check the behaviour of the default constructor without using the class
methods."""
n_classes = 5
X, y = make_classification(
n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
)
classifier = SVC().fit(X, y)
y_pred = classifier.predict(X)
# safe guard for the binary if/else construction
assert constructor_name in ("from_estimator", "from_predictions")
cm = confusion_matrix(y, y_pred)
common_kwargs = {
"normalize": None,
"include_values": True,
"cmap": "viridis",
"xticks_rotation": 45.0,
}
if constructor_name == "from_estimator":
disp = ConfusionMatrixDisplay.from_estimator(classifier, X, y, **common_kwargs)
else:
disp = ConfusionMatrixDisplay.from_predictions(y, y_pred, **common_kwargs)
assert_allclose(disp.confusion_matrix, cm)
assert disp.text_.shape == (n_classes, n_classes)
rotations = [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]
assert_allclose(rotations, 45.0)
image_data = disp.im_.get_array().data
assert_allclose(image_data, cm)
disp.plot(cmap="plasma")
assert disp.im_.get_cmap().name == "plasma"
disp.plot(include_values=False)
assert disp.text_ is None
disp.plot(xticks_rotation=90.0)
rotations = [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]
assert_allclose(rotations, 90.0)
disp.plot(values_format="e")
expected_text = np.array([format(v, "e") for v in cm.ravel(order="C")])
text_text = np.array([t.get_text() for t in disp.text_.ravel(order="C")])
assert_array_equal(expected_text, text_text)
def test_confusion_matrix_contrast(pyplot):
"""Check that the text color is appropriate depending on background."""
cm = np.eye(2) / 2
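    # Diagonal cells hold the maximum value (0.5) and off-diagonal cells the
    # minimum (0.0), so the text color must contrast with the background and
    # flip when the colormap is reversed.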
disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
disp.plot(cmap=pyplot.cm.gray)
# diagonal text is black
assert_allclose(disp.text_[0, 0].get_color(), [0.0, 0.0, 0.0, 1.0])
assert_allclose(disp.text_[1, 1].get_color(), [0.0, 0.0, 0.0, 1.0])
# off-diagonal text is white
assert_allclose(disp.text_[0, 1].get_color(), [1.0, 1.0, 1.0, 1.0])
assert_allclose(disp.text_[1, 0].get_color(), [1.0, 1.0, 1.0, 1.0])
disp.plot(cmap=pyplot.cm.gray_r)
    # off-diagonal text is black
    assert_allclose(disp.text_[0, 1].get_color(), [0.0, 0.0, 0.0, 1.0])
    assert_allclose(disp.text_[1, 0].get_color(), [0.0, 0.0, 0.0, 1.0])
    # diagonal text is white
    assert_allclose(disp.text_[0, 0].get_color(), [1.0, 1.0, 1.0, 1.0])
    assert_allclose(disp.text_[1, 1].get_color(), [1.0, 1.0, 1.0, 1.0])
# Regression test for #15920
cm = np.array([[19, 34], [32, 58]])
disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
disp.plot(cmap=pyplot.cm.Blues)
min_color = pyplot.cm.Blues(0)
max_color = pyplot.cm.Blues(255)
assert_allclose(disp.text_[0, 0].get_color(), max_color)
assert_allclose(disp.text_[0, 1].get_color(), max_color)
assert_allclose(disp.text_[1, 0].get_color(), max_color)
assert_allclose(disp.text_[1, 1].get_color(), min_color)
@pytest.mark.parametrize(
"clf",
[
LogisticRegression(),
make_pipeline(StandardScaler(), LogisticRegression()),
make_pipeline(
make_column_transformer((StandardScaler(), [0, 1])),
LogisticRegression(),
),
],
ids=["clf", "pipeline-clf", "pipeline-column_transformer-clf"],
)
def test_confusion_matrix_pipeline(pyplot, clf):
"""Check the behaviour of the plotting with more complex pipeline."""
n_classes = 5
X, y = make_classification(
n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
)
with pytest.raises(NotFittedError):
ConfusionMatrixDisplay.from_estimator(clf, X, y)
clf.fit(X, y)
y_pred = clf.predict(X)
disp = ConfusionMatrixDisplay.from_estimator(clf, X, y)
cm = confusion_matrix(y, y_pred)
assert_allclose(disp.confusion_matrix, cm)
assert disp.text_.shape == (n_classes, n_classes)
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_confusion_matrix_with_unknown_labels(pyplot, constructor_name):
"""Check that when labels=None, the unique values in `y_pred` and `y_true`
will be used.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/pull/18405
"""
n_classes = 5
X, y = make_classification(
n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
)
classifier = SVC().fit(X, y)
y_pred = classifier.predict(X)
    # create labels in `y_true` that were not seen during fitting and are not
    # present in `classifier.classes_`
y = y + 1
# safe guard for the binary if/else construction
assert constructor_name in ("from_estimator", "from_predictions")
common_kwargs = {"labels": None}
if constructor_name == "from_estimator":
disp = ConfusionMatrixDisplay.from_estimator(classifier, X, y, **common_kwargs)
else:
disp = ConfusionMatrixDisplay.from_predictions(y, y_pred, **common_kwargs)
display_labels = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
expected_labels = [str(i) for i in range(n_classes + 1)]
assert_array_equal(expected_labels, display_labels)
def test_colormap_max(pyplot):
"""Check that the max color is used for the color of the text."""
gray = pyplot.get_cmap("gray", 1024)
confusion_matrix = np.array([[1.0, 0.0], [0.0, 1.0]])
disp = ConfusionMatrixDisplay(confusion_matrix)
disp.plot(cmap=gray)
color = disp.text_[1, 0].get_color()
assert_allclose(color, [1.0, 1.0, 1.0, 1.0])
def test_im_kw_adjust_vmin_vmax(pyplot):
"""Check that im_kw passes kwargs to imshow"""
confusion_matrix = np.array([[0.48, 0.04], [0.08, 0.4]])
disp = ConfusionMatrixDisplay(confusion_matrix)
disp.plot(im_kw=dict(vmin=0.0, vmax=0.8))
clim = disp.im_.get_clim()
assert clim[0] == pytest.approx(0.0)
assert clim[1] == pytest.approx(0.8)
def test_confusion_matrix_text_kw(pyplot):
"""Check that text_kw is passed to the text call."""
font_size = 15.0
X, y = make_classification(random_state=0)
classifier = SVC().fit(X, y)
# from_estimator passes the font size
disp = ConfusionMatrixDisplay.from_estimator(
classifier, X, y, text_kw={"fontsize": font_size}
)
for text in disp.text_.reshape(-1):
assert text.get_fontsize() == font_size
# plot adjusts plot to new font size
new_font_size = 20.0
disp.plot(text_kw={"fontsize": new_font_size})
for text in disp.text_.reshape(-1):
assert text.get_fontsize() == new_font_size
# from_predictions passes the font size
y_pred = classifier.predict(X)
disp = ConfusionMatrixDisplay.from_predictions(
y, y_pred, text_kw={"fontsize": font_size}
)
for text in disp.text_.reshape(-1):
assert text.get_fontsize() == font_size

View File

@@ -0,0 +1,114 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import DetCurveDisplay, det_curve
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("drop_intermediate", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
def test_det_curve_display(
pyplot,
constructor_name,
response_method,
with_sample_weight,
drop_intermediate,
with_strings,
):
X, y = load_iris(return_X_y=True)
# Binarize the data with only the two first classes
X, y = X[y < 2], y[y < 2]
pos_label = None
if with_strings:
y = np.array(["c", "b"])[y]
pos_label = "c"
if with_sample_weight:
rng = np.random.RandomState(42)
sample_weight = rng.randint(1, 4, size=(X.shape[0]))
else:
sample_weight = None
lr = LogisticRegression()
lr.fit(X, y)
y_pred = getattr(lr, response_method)(X)
if y_pred.ndim == 2:
y_pred = y_pred[:, 1]
# safe guard for the binary if/else construction
assert constructor_name in ("from_estimator", "from_predictions")
common_kwargs = {
"name": lr.__class__.__name__,
"alpha": 0.8,
"sample_weight": sample_weight,
"drop_intermediate": drop_intermediate,
"pos_label": pos_label,
}
if constructor_name == "from_estimator":
disp = DetCurveDisplay.from_estimator(lr, X, y, **common_kwargs)
else:
disp = DetCurveDisplay.from_predictions(y, y_pred, **common_kwargs)
fpr, fnr, _ = det_curve(
y,
y_pred,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
pos_label=pos_label,
)
assert_allclose(disp.fpr, fpr, atol=1e-7)
assert_allclose(disp.fnr, fnr, atol=1e-7)
assert disp.estimator_name == "LogisticRegression"
# cannot fail thanks to pyplot fixture
import matplotlib as mpl
assert isinstance(disp.line_, mpl.lines.Line2D)
assert disp.line_.get_alpha() == 0.8
assert isinstance(disp.ax_, mpl.axes.Axes)
assert isinstance(disp.figure_, mpl.figure.Figure)
assert disp.line_.get_label() == "LogisticRegression"
expected_pos_label = 1 if pos_label is None else pos_label
expected_ylabel = f"False Negative Rate (Positive label: {expected_pos_label})"
expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})"
assert disp.ax_.get_ylabel() == expected_ylabel
assert disp.ax_.get_xlabel() == expected_xlabel
@pytest.mark.parametrize(
"constructor_name, expected_clf_name",
[
("from_estimator", "LogisticRegression"),
("from_predictions", "Classifier"),
],
)
def test_det_curve_display_default_name(
pyplot,
constructor_name,
expected_clf_name,
):
# Check the default name display in the figure when `name` is not provided
X, y = load_iris(return_X_y=True)
# Binarize the data with only the two first classes
X, y = X[y < 2], y[y < 2]
lr = LogisticRegression().fit(X, y)
y_pred = lr.predict_proba(X)[:, 1]
if constructor_name == "from_estimator":
disp = DetCurveDisplay.from_estimator(lr, X, y)
else:
disp = DetCurveDisplay.from_predictions(y, y_pred)
assert disp.estimator_name == expected_clf_name
assert disp.line_.get_label() == expected_clf_name

View File

@@ -0,0 +1,382 @@
from collections import Counter
import numpy as np
import pytest
from scipy.integrate import trapezoid
from sklearn.compose import make_column_transformer
from sklearn.datasets import load_breast_cancer, make_classification
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
PrecisionRecallDisplay,
average_precision_score,
precision_recall_curve,
)
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("drop_intermediate", [True, False])
def test_precision_recall_display_plotting(
pyplot, constructor_name, response_method, drop_intermediate
):
"""Check the overall plotting rendering."""
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
pos_label = 1
    classifier = LogisticRegression().fit(X, y)
y_pred = getattr(classifier, response_method)(X)
y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, pos_label]
# safe guard for the binary if/else construction
assert constructor_name in ("from_estimator", "from_predictions")
if constructor_name == "from_estimator":
display = PrecisionRecallDisplay.from_estimator(
classifier,
X,
y,
response_method=response_method,
drop_intermediate=drop_intermediate,
)
else:
display = PrecisionRecallDisplay.from_predictions(
y, y_pred, pos_label=pos_label, drop_intermediate=drop_intermediate
)
precision, recall, _ = precision_recall_curve(
y, y_pred, pos_label=pos_label, drop_intermediate=drop_intermediate
)
average_precision = average_precision_score(y, y_pred, pos_label=pos_label)
np.testing.assert_allclose(display.precision, precision)
np.testing.assert_allclose(display.recall, recall)
assert display.average_precision == pytest.approx(average_precision)
import matplotlib as mpl
assert isinstance(display.line_, mpl.lines.Line2D)
assert isinstance(display.ax_, mpl.axes.Axes)
assert isinstance(display.figure_, mpl.figure.Figure)
assert display.ax_.get_xlabel() == "Recall (Positive label: 1)"
assert display.ax_.get_ylabel() == "Precision (Positive label: 1)"
assert display.ax_.get_adjustable() == "box"
assert display.ax_.get_aspect() in ("equal", 1.0)
assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01)
# plotting passing some new parameters
display.plot(alpha=0.8, name="MySpecialEstimator")
expected_label = f"MySpecialEstimator (AP = {average_precision:0.2f})"
assert display.line_.get_label() == expected_label
assert display.line_.get_alpha() == pytest.approx(0.8)
# Check that the chance level line is not plotted by default
assert display.chance_level_ is None
@pytest.mark.parametrize("chance_level_kw", [None, {"color": "r"}, {"c": "r"}])
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_precision_recall_chance_level_line(
pyplot,
chance_level_kw,
constructor_name,
):
"""Check the chance level line plotting behavior."""
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
pos_prevalence = Counter(y)[1] / len(y)
lr = LogisticRegression()
y_pred = lr.fit(X, y).predict_proba(X)[:, 1]
if constructor_name == "from_estimator":
display = PrecisionRecallDisplay.from_estimator(
lr,
X,
y,
plot_chance_level=True,
chance_level_kw=chance_level_kw,
)
else:
display = PrecisionRecallDisplay.from_predictions(
y,
y_pred,
plot_chance_level=True,
chance_level_kw=chance_level_kw,
)
import matplotlib as mpl
assert isinstance(display.chance_level_, mpl.lines.Line2D)
assert tuple(display.chance_level_.get_xdata()) == (0, 1)
assert tuple(display.chance_level_.get_ydata()) == (pos_prevalence, pos_prevalence)
# Checking for chance level line styles
if chance_level_kw is None:
assert display.chance_level_.get_color() == "k"
else:
assert display.chance_level_.get_color() == "r"
@pytest.mark.parametrize(
"constructor_name, default_label",
[
("from_estimator", "LogisticRegression (AP = {:.2f})"),
("from_predictions", "Classifier (AP = {:.2f})"),
],
)
def test_precision_recall_display_name(pyplot, constructor_name, default_label):
"""Check the behaviour of the name parameters"""
X, y = make_classification(n_classes=2, n_samples=100, random_state=0)
pos_label = 1
    classifier = LogisticRegression().fit(X, y)
y_pred = classifier.predict_proba(X)[:, pos_label]
# safe guard for the binary if/else construction
assert constructor_name in ("from_estimator", "from_predictions")
if constructor_name == "from_estimator":
display = PrecisionRecallDisplay.from_estimator(classifier, X, y)
else:
display = PrecisionRecallDisplay.from_predictions(
y, y_pred, pos_label=pos_label
)
average_precision = average_precision_score(y, y_pred, pos_label=pos_label)
# check that the default name is used
assert display.line_.get_label() == default_label.format(average_precision)
# check that the name can be set
display.plot(name="MySpecialEstimator")
assert (
display.line_.get_label()
== f"MySpecialEstimator (AP = {average_precision:.2f})"
)
@pytest.mark.parametrize(
"clf",
[
make_pipeline(StandardScaler(), LogisticRegression()),
make_pipeline(
make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
),
],
)
def test_precision_recall_display_pipeline(pyplot, clf):
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
with pytest.raises(NotFittedError):
PrecisionRecallDisplay.from_estimator(clf, X, y)
clf.fit(X, y)
display = PrecisionRecallDisplay.from_estimator(clf, X, y)
assert display.estimator_name == clf.__class__.__name__
def test_precision_recall_display_string_labels(pyplot):
# regression test #15738
cancer = load_breast_cancer()
X, y = cancer.data, cancer.target_names[cancer.target]
lr = make_pipeline(StandardScaler(), LogisticRegression())
lr.fit(X, y)
for klass in cancer.target_names:
assert klass in lr.classes_
display = PrecisionRecallDisplay.from_estimator(lr, X, y)
y_pred = lr.predict_proba(X)[:, 1]
avg_prec = average_precision_score(y, y_pred, pos_label=lr.classes_[1])
assert display.average_precision == pytest.approx(avg_prec)
assert display.estimator_name == lr.__class__.__name__
err_msg = r"y_true takes value in {'benign', 'malignant'}"
with pytest.raises(ValueError, match=err_msg):
PrecisionRecallDisplay.from_predictions(y, y_pred)
display = PrecisionRecallDisplay.from_predictions(
y, y_pred, pos_label=lr.classes_[1]
)
assert display.average_precision == pytest.approx(avg_prec)
@pytest.mark.parametrize(
"average_precision, estimator_name, expected_label",
[
(0.9, None, "AP = 0.90"),
(None, "my_est", "my_est"),
(0.8, "my_est2", "my_est2 (AP = 0.80)"),
],
)
def test_default_labels(pyplot, average_precision, estimator_name, expected_label):
"""Check the default labels used in the display."""
precision = np.array([1, 0.5, 0])
recall = np.array([0, 0.5, 1])
display = PrecisionRecallDisplay(
precision,
recall,
average_precision=average_precision,
estimator_name=estimator_name,
)
display.plot()
assert display.line_.get_label() == expected_label
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_method):
# check that we can provide the positive label and display the proper
# statistics
X, y = load_breast_cancer(return_X_y=True)
    # create a highly imbalanced version of the breast cancer dataset
idx_positive = np.flatnonzero(y == 1)
idx_negative = np.flatnonzero(y == 0)
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
X, y = X[idx_selected], y[idx_selected]
X, y = shuffle(X, y, random_state=42)
# only use 2 features to make the problem even harder
X = X[:, :2]
y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
X_train, X_test, y_train, y_test = train_test_split(
X,
y,
stratify=y,
random_state=0,
)
classifier = LogisticRegression()
classifier.fit(X_train, y_train)
    # sanity check: the positive class is classes_[0], so relying on the
    # default pos_label would be misleading because of the class imbalance
assert classifier.classes_.tolist() == ["cancer", "not cancer"]
y_pred = getattr(classifier, response_method)(X_test)
# we select the corresponding probability columns or reverse the decision
# function otherwise
y_pred_cancer = -1 * y_pred if y_pred.ndim == 1 else y_pred[:, 0]
y_pred_not_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 1]
if constructor_name == "from_estimator":
display = PrecisionRecallDisplay.from_estimator(
classifier,
X_test,
y_test,
pos_label="cancer",
response_method=response_method,
)
else:
display = PrecisionRecallDisplay.from_predictions(
y_test,
y_pred_cancer,
pos_label="cancer",
)
# we should obtain the statistics of the "cancer" class
avg_prec_limit = 0.65
assert display.average_precision < avg_prec_limit
assert -trapezoid(display.precision, display.recall) < avg_prec_limit
# otherwise we should obtain the statistics of the "not cancer" class
if constructor_name == "from_estimator":
display = PrecisionRecallDisplay.from_estimator(
classifier,
X_test,
y_test,
response_method=response_method,
pos_label="not cancer",
)
else:
display = PrecisionRecallDisplay.from_predictions(
y_test,
y_pred_not_cancer,
pos_label="not cancer",
)
avg_prec_limit = 0.95
assert display.average_precision > avg_prec_limit
assert -trapezoid(display.precision, display.recall) > avg_prec_limit
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_precision_recall_prevalence_pos_label_reusable(pyplot, constructor_name):
# Check that even if one passes plot_chance_level=False the first time
# one can still call disp.plot with plot_chance_level=True and get the
# chance level line
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
lr = LogisticRegression()
y_pred = lr.fit(X, y).predict_proba(X)[:, 1]
if constructor_name == "from_estimator":
display = PrecisionRecallDisplay.from_estimator(
lr, X, y, plot_chance_level=False
)
else:
display = PrecisionRecallDisplay.from_predictions(
y, y_pred, plot_chance_level=False
)
assert display.chance_level_ is None
import matplotlib as mpl
# When calling from_estimator or from_predictions,
# prevalence_pos_label should have been set, so that directly
# calling plot_chance_level=True should plot the chance level line
display.plot(plot_chance_level=True)
assert isinstance(display.chance_level_, mpl.lines.Line2D)
def test_precision_recall_raise_no_prevalence(pyplot):
    # Check that an error is raised when plotting the chance level without
    # providing prevalence_pos_label
precision = np.array([1, 0.5, 0])
recall = np.array([0, 0.5, 1])
display = PrecisionRecallDisplay(precision, recall)
msg = (
"You must provide prevalence_pos_label when constructing the "
"PrecisionRecallDisplay object in order to plot the chance "
"level line. Alternatively, you may use "
"PrecisionRecallDisplay.from_estimator or "
"PrecisionRecallDisplay.from_predictions "
"to automatically set prevalence_pos_label"
)
with pytest.raises(ValueError, match=msg):
display.plot(plot_chance_level=True)
@pytest.mark.parametrize("despine", [True, False])
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_plot_precision_recall_despine(pyplot, despine, constructor_name):
# Check that the despine keyword is working correctly
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
    clf = LogisticRegression().fit(X, y)
y_pred = clf.decision_function(X)
# safe guard for the binary if/else construction
assert constructor_name in ("from_estimator", "from_predictions")
if constructor_name == "from_estimator":
display = PrecisionRecallDisplay.from_estimator(clf, X, y, despine=despine)
else:
display = PrecisionRecallDisplay.from_predictions(y, y_pred, despine=despine)
for s in ["top", "right"]:
assert display.ax_.spines[s].get_visible() is not despine
if despine:
for s in ["bottom", "left"]:
assert display.ax_.spines[s].get_bounds() == (0, 1)

View File

@@ -0,0 +1,169 @@
import pytest
from numpy.testing import assert_allclose
from sklearn.datasets import load_diabetes
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import Ridge
from sklearn.metrics import PredictionErrorDisplay
X, y = load_diabetes(return_X_y=True)
@pytest.fixture
def regressor_fitted():
return Ridge().fit(X, y)
@pytest.mark.parametrize(
"regressor, params, err_type, err_msg",
[
(
Ridge().fit(X, y),
{"subsample": -1},
ValueError,
"When an integer, subsample=-1 should be",
),
(
Ridge().fit(X, y),
{"subsample": 20.0},
ValueError,
"When a floating-point, subsample=20.0 should be",
),
(
Ridge().fit(X, y),
{"subsample": -20.0},
ValueError,
"When a floating-point, subsample=-20.0 should be",
),
(
Ridge().fit(X, y),
{"kind": "xxx"},
ValueError,
"`kind` must be one of",
),
],
)
@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
def test_prediction_error_display_raise_error(
pyplot, class_method, regressor, params, err_type, err_msg
):
"""Check that we raise the proper error when making the parameters
# validation."""
with pytest.raises(err_type, match=err_msg):
if class_method == "from_estimator":
PredictionErrorDisplay.from_estimator(regressor, X, y, **params)
else:
y_pred = regressor.predict(X)
PredictionErrorDisplay.from_predictions(y_true=y, y_pred=y_pred, **params)
def test_from_estimator_not_fitted(pyplot):
"""Check that we raise a `NotFittedError` when the passed regressor is not
fit."""
regressor = Ridge()
with pytest.raises(NotFittedError, match="is not fitted yet."):
PredictionErrorDisplay.from_estimator(regressor, X, y)
@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("kind", ["actual_vs_predicted", "residual_vs_predicted"])
def test_prediction_error_display(pyplot, regressor_fitted, class_method, kind):
"""Check the default behaviour of the display."""
if class_method == "from_estimator":
display = PredictionErrorDisplay.from_estimator(
regressor_fitted, X, y, kind=kind
)
else:
y_pred = regressor_fitted.predict(X)
display = PredictionErrorDisplay.from_predictions(
y_true=y, y_pred=y_pred, kind=kind
)
if kind == "actual_vs_predicted":
assert_allclose(display.line_.get_xdata(), display.line_.get_ydata())
assert display.ax_.get_xlabel() == "Predicted values"
assert display.ax_.get_ylabel() == "Actual values"
assert display.line_ is not None
else:
assert display.ax_.get_xlabel() == "Predicted values"
assert display.ax_.get_ylabel() == "Residuals (actual - predicted)"
assert display.line_ is not None
assert display.ax_.get_legend() is None
@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize(
"subsample, expected_size",
[(5, 5), (0.1, int(X.shape[0] * 0.1)), (None, X.shape[0])],
)
def test_plot_prediction_error_subsample(
pyplot, regressor_fitted, class_method, subsample, expected_size
):
"""Check the behaviour of `subsample`."""
if class_method == "from_estimator":
display = PredictionErrorDisplay.from_estimator(
regressor_fitted, X, y, subsample=subsample
)
else:
y_pred = regressor_fitted.predict(X)
display = PredictionErrorDisplay.from_predictions(
y_true=y, y_pred=y_pred, subsample=subsample
)
assert len(display.scatter_.get_offsets()) == expected_size
@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
def test_plot_prediction_error_ax(pyplot, regressor_fitted, class_method):
"""Check that we can pass an axis to the display."""
_, ax = pyplot.subplots()
if class_method == "from_estimator":
display = PredictionErrorDisplay.from_estimator(regressor_fitted, X, y, ax=ax)
else:
y_pred = regressor_fitted.predict(X)
display = PredictionErrorDisplay.from_predictions(
y_true=y, y_pred=y_pred, ax=ax
)
assert display.ax_ is ax
@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize(
"scatter_kwargs",
[None, {"color": "blue", "alpha": 0.9}, {"c": "blue", "alpha": 0.9}],
)
@pytest.mark.parametrize(
"line_kwargs", [None, {"color": "red", "linestyle": "-"}, {"c": "red", "ls": "-"}]
)
def test_prediction_error_custom_artist(
pyplot, regressor_fitted, class_method, scatter_kwargs, line_kwargs
):
"""Check that we can tune the style of the line and the scatter."""
extra_params = {
"kind": "actual_vs_predicted",
"scatter_kwargs": scatter_kwargs,
"line_kwargs": line_kwargs,
}
if class_method == "from_estimator":
display = PredictionErrorDisplay.from_estimator(
regressor_fitted, X, y, **extra_params
)
else:
y_pred = regressor_fitted.predict(X)
display = PredictionErrorDisplay.from_predictions(
y_true=y, y_pred=y_pred, **extra_params
)
if line_kwargs is not None:
assert display.line_.get_linestyle() == "-"
assert display.line_.get_color() == "red"
else:
assert display.line_.get_linestyle() == "--"
assert display.line_.get_color() == "black"
assert display.line_.get_alpha() == 0.7
if scatter_kwargs is not None:
assert_allclose(display.scatter_.get_facecolor(), [[0.0, 0.0, 1.0, 0.9]])
assert_allclose(display.scatter_.get_edgecolor(), [[0.0, 0.0, 1.0, 0.9]])
else:
assert display.scatter_.get_alpha() == 0.8
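# Sketch of the style hooks exercised above: `scatter_kwargs` is forwarded to
# the scatter points and `line_kwargs` to the reference line. Data and
# estimator choices here are illustrative assumptions.
def _example_prediction_error_styling():
    from sklearn.datasets import make_regression
    from sklearn.linear_model import Ridge
    from sklearn.metrics import PredictionErrorDisplay

    X_demo, y_demo = make_regression(n_samples=100, random_state=0)
    ridge = Ridge().fit(X_demo, y_demo)
    PredictionErrorDisplay.from_estimator(
        ridge,
        X_demo,
        y_demo,
        kind="actual_vs_predicted",
        scatter_kwargs={"color": "blue", "alpha": 0.9},
        line_kwargs={"color": "red", "linestyle": "-"},
    )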

View File

@@ -0,0 +1,987 @@
from collections.abc import Mapping
import numpy as np
import pytest
from numpy.testing import assert_allclose
from scipy.integrate import trapezoid
from sklearn import clone
from sklearn.compose import make_column_transformer
from sklearn.datasets import load_breast_cancer, make_classification
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import RocCurveDisplay, auc, roc_curve
from sklearn.model_selection import cross_validate, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils import _safe_indexing, shuffle
from sklearn.utils._response import _get_response_values_binary
@pytest.fixture(scope="module")
def data_binary():
X, y = make_classification(
n_samples=200,
n_features=20,
n_informative=5,
n_redundant=2,
flip_y=0.1,
class_sep=0.8,
random_state=42,
)
return X, y
def _check_figure_axes_and_labels(display, pos_label):
"""Check mpl axes and figure defaults are correct."""
import matplotlib as mpl
assert isinstance(display.ax_, mpl.axes.Axes)
assert isinstance(display.figure_, mpl.figure.Figure)
assert display.ax_.get_adjustable() == "box"
assert display.ax_.get_aspect() in ("equal", 1.0)
assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01)
expected_pos_label = 1 if pos_label is None else pos_label
expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})"
expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})"
assert display.ax_.get_ylabel() == expected_ylabel
assert display.ax_.get_xlabel() == expected_xlabel
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("drop_intermediate", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
@pytest.mark.parametrize(
"constructor_name, default_name",
[
("from_estimator", "LogisticRegression"),
("from_predictions", "Classifier"),
],
)
def test_roc_curve_display_plotting(
pyplot,
response_method,
data_binary,
with_sample_weight,
drop_intermediate,
with_strings,
constructor_name,
default_name,
):
"""Check the overall plotting behaviour for single curve."""
X, y = data_binary
pos_label = None
if with_strings:
y = np.array(["c", "b"])[y]
pos_label = "c"
if with_sample_weight:
rng = np.random.RandomState(42)
sample_weight = rng.randint(1, 4, size=(X.shape[0]))
else:
sample_weight = None
lr = LogisticRegression()
lr.fit(X, y)
y_score = getattr(lr, response_method)(X)
y_score = y_score if y_score.ndim == 1 else y_score[:, 1]
if constructor_name == "from_estimator":
display = RocCurveDisplay.from_estimator(
lr,
X,
y,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
pos_label=pos_label,
curve_kwargs={"alpha": 0.8},
)
else:
display = RocCurveDisplay.from_predictions(
y,
y_score,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
pos_label=pos_label,
curve_kwargs={"alpha": 0.8},
)
fpr, tpr, _ = roc_curve(
y,
y_score,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
pos_label=pos_label,
)
assert_allclose(display.roc_auc, auc(fpr, tpr))
assert_allclose(display.fpr, fpr)
assert_allclose(display.tpr, tpr)
assert display.name == default_name
import matplotlib as mpl
_check_figure_axes_and_labels(display, pos_label)
assert isinstance(display.line_, mpl.lines.Line2D)
assert display.line_.get_alpha() == 0.8
expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})"
assert display.line_.get_label() == expected_label
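# Sketch of the equivalence relied on above: the display stores the same
# `fpr`/`tpr`/`roc_auc` as a manual `roc_curve` + `auc` computation on the same
# scores. The dataset is an illustrative assumption.
def _example_roc_display_matches_roc_curve():
    X_demo, y_demo = make_classification(random_state=0)
    clf = LogisticRegression().fit(X_demo, y_demo)
    y_score = clf.predict_proba(X_demo)[:, 1]
    display = RocCurveDisplay.from_predictions(y_demo, y_score)
    fpr, tpr, _ = roc_curve(y_demo, y_score)
    assert_allclose(display.fpr, fpr)
    assert_allclose(display.tpr, tpr)
    assert_allclose(display.roc_auc, auc(fpr, tpr))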
@pytest.mark.parametrize(
"params, err_msg",
[
(
{
"fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
"tpr": [np.array([0, 0.5, 1])],
"roc_auc": None,
"name": None,
},
"self.fpr and self.tpr from `RocCurveDisplay` initialization,",
),
(
{
"fpr": [np.array([0, 0.5, 1])],
"tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
"roc_auc": [0.8, 0.9],
"name": None,
},
"self.fpr, self.tpr and self.roc_auc from `RocCurveDisplay`",
),
(
{
"fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
"tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
"roc_auc": [0.8],
"name": None,
},
"Got: self.fpr: 2, self.tpr: 2, self.roc_auc: 1",
),
(
{
"fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
"tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
"roc_auc": [0.8, 0.9],
"name": ["curve1", "curve2", "curve3"],
},
r"self.fpr, self.tpr, self.roc_auc and 'name' \(or self.name\)",
),
(
{
"fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
"tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
"roc_auc": [0.8, 0.9],
# List of length 1 is always allowed
"name": ["curve1"],
},
None,
),
],
)
def test_roc_curve_plot_parameter_length_validation(pyplot, params, err_msg):
"""Check `plot` parameter length validation performed correctly."""
display = RocCurveDisplay(**params)
if err_msg:
with pytest.raises(ValueError, match=err_msg):
display.plot()
else:
# No error should be raised
display.plot()
def test_validate_plot_params(pyplot):
"""Check `_validate_plot_params` returns the correct variables."""
fpr = np.array([0, 0.5, 1])
tpr = [np.array([0, 0.5, 1])]
roc_auc = None
name = "test_curve"
# Initialize display with test inputs
display = RocCurveDisplay(
fpr=fpr,
tpr=tpr,
roc_auc=roc_auc,
name=name,
pos_label=None,
)
fpr_out, tpr_out, roc_auc_out, name_out = display._validate_plot_params(
ax=None, name=None
)
assert isinstance(fpr_out, list)
assert isinstance(tpr_out, list)
assert len(fpr_out) == 1
assert len(tpr_out) == 1
assert roc_auc_out is None
assert name_out == ["test_curve"]
def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary):
"""Check parameter validation is correct."""
X, y = data_binary
# `cv_results` missing key
    cv_results_no_est = cross_validate(
        LogisticRegression(), X, y, cv=3, return_estimator=False, return_indices=True
    )
cv_results_no_indices = cross_validate(
LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=False
)
for cv_results in (cv_results_no_est, cv_results_no_indices):
with pytest.raises(
ValueError,
match="`cv_results` does not contain one of the following required",
):
RocCurveDisplay.from_cv_results(cv_results, X, y)
cv_results = cross_validate(
LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True
)
# `X` wrong length
with pytest.raises(ValueError, match="`X` does not contain the correct"):
RocCurveDisplay.from_cv_results(cv_results, X[:10, :], y)
# `y` not binary
y_multi = y.copy()
y_multi[0] = 2
with pytest.raises(ValueError, match="The target `y` is not binary."):
RocCurveDisplay.from_cv_results(cv_results, X, y_multi)
# input inconsistent length
with pytest.raises(ValueError, match="Found input variables with inconsistent"):
RocCurveDisplay.from_cv_results(cv_results, X, y[:10])
with pytest.raises(ValueError, match="Found input variables with inconsistent"):
RocCurveDisplay.from_cv_results(cv_results, X, y, sample_weight=[1, 2])
# `pos_label` inconsistency
y_multi[y_multi == 1] = 2
with pytest.raises(ValueError, match=r"y takes value in \{0, 2\}"):
RocCurveDisplay.from_cv_results(cv_results, X, y_multi)
# `name` is list while `curve_kwargs` is None or dict
for curve_kwargs in (None, {"alpha": 0.2}):
with pytest.raises(ValueError, match="To avoid labeling individual curves"):
RocCurveDisplay.from_cv_results(
cv_results,
X,
y,
name=["one", "two", "three"],
curve_kwargs=curve_kwargs,
)
# `curve_kwargs` incorrect length
with pytest.raises(ValueError, match="`curve_kwargs` must be None, a dictionary"):
RocCurveDisplay.from_cv_results(cv_results, X, y, curve_kwargs=[{"alpha": 1}])
# `curve_kwargs` both alias provided
with pytest.raises(TypeError, match="Got both c and"):
RocCurveDisplay.from_cv_results(
cv_results, X, y, curve_kwargs={"c": "blue", "color": "red"}
)
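# Sketch of the `from_cv_results` input contract validated above: the
# `cv_results` mapping must come from `cross_validate` run with
# `return_estimator=True` and `return_indices=True`, on the same `X` and a
# binary `y`. The helper name and argument names are illustrative.
def _example_from_cv_results_inputs(X_demo, y_demo):
    cv_results = cross_validate(
        LogisticRegression(),
        X_demo,
        y_demo,
        cv=3,
        return_estimator=True,
        return_indices=True,
    )
    # One ROC curve per CV fold, drawn on a single axes.
    return RocCurveDisplay.from_cv_results(cv_results, X_demo, y_demo)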
@pytest.mark.parametrize(
"curve_kwargs",
[None, {"alpha": 0.2}, [{"alpha": 0.2}, {"alpha": 0.3}, {"alpha": 0.4}]],
)
def test_roc_curve_display_from_cv_results_curve_kwargs(
pyplot, data_binary, curve_kwargs
):
"""Check `curve_kwargs` correctly passed."""
X, y = data_binary
n_cv = 3
cv_results = cross_validate(
LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True
)
display = RocCurveDisplay.from_cv_results(
cv_results,
X,
y,
curve_kwargs=curve_kwargs,
)
if curve_kwargs is None:
# Default `alpha` used
assert all(line.get_alpha() == 0.5 for line in display.line_)
elif isinstance(curve_kwargs, Mapping):
# `alpha` from dict used for all curves
assert all(line.get_alpha() == 0.2 for line in display.line_)
else:
# Different `alpha` used for each curve
assert all(
line.get_alpha() == curve_kwargs[i]["alpha"]
for i, line in enumerate(display.line_)
)
# TODO(1.9): Remove in 1.9
def test_roc_curve_display_estimator_name_deprecation(pyplot):
"""Check deprecation of `estimator_name`."""
fpr = np.array([0, 0.5, 1])
tpr = np.array([0, 0.5, 1])
with pytest.warns(FutureWarning, match="`estimator_name` is deprecated in"):
RocCurveDisplay(fpr=fpr, tpr=tpr, estimator_name="test")
# TODO(1.9): Remove in 1.9
@pytest.mark.parametrize(
"constructor_name", ["from_estimator", "from_predictions", "plot"]
)
def test_roc_curve_display_kwargs_deprecation(pyplot, data_binary, constructor_name):
"""Check **kwargs deprecated correctly in favour of `curve_kwargs`."""
X, y = data_binary
lr = LogisticRegression()
lr.fit(X, y)
fpr = np.array([0, 0.5, 1])
tpr = np.array([0, 0.5, 1])
# Error when both `curve_kwargs` and `**kwargs` provided
with pytest.raises(ValueError, match="Cannot provide both `curve_kwargs`"):
if constructor_name == "from_estimator":
RocCurveDisplay.from_estimator(
lr, X, y, curve_kwargs={"alpha": 1}, label="test"
)
elif constructor_name == "from_predictions":
RocCurveDisplay.from_predictions(
y, y, curve_kwargs={"alpha": 1}, label="test"
)
else:
RocCurveDisplay(fpr=fpr, tpr=tpr).plot(
curve_kwargs={"alpha": 1}, label="test"
)
    # Warning when `**kwargs` provided
with pytest.warns(FutureWarning, match=r"`\*\*kwargs` is deprecated and will be"):
if constructor_name == "from_estimator":
RocCurveDisplay.from_estimator(lr, X, y, label="test")
elif constructor_name == "from_predictions":
RocCurveDisplay.from_predictions(y, y, label="test")
else:
RocCurveDisplay(fpr=fpr, tpr=tpr).plot(label="test")
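# Sketch of the non-deprecated spelling checked above: per-curve line styling
# and the legend label go through `curve_kwargs` rather than bare `**kwargs`.
# The argument names `fitted_clf`, `X_demo`, `y_demo` are illustrative.
def _example_curve_kwargs(fitted_clf, X_demo, y_demo):
    return RocCurveDisplay.from_estimator(
        fitted_clf, X_demo, y_demo, curve_kwargs={"alpha": 0.8, "label": "test"}
    )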
@pytest.mark.parametrize(
"curve_kwargs",
[
None,
{"color": "blue"},
[{"color": "blue"}, {"color": "green"}, {"color": "red"}],
],
)
@pytest.mark.parametrize("drop_intermediate", [True, False])
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
def test_roc_curve_display_plotting_from_cv_results(
pyplot,
data_binary,
with_strings,
with_sample_weight,
response_method,
drop_intermediate,
curve_kwargs,
):
"""Check overall plotting of `from_cv_results`."""
X, y = data_binary
pos_label = None
if with_strings:
y = np.array(["c", "b"])[y]
pos_label = "c"
if with_sample_weight:
rng = np.random.RandomState(42)
sample_weight = rng.randint(1, 4, size=(X.shape[0]))
else:
sample_weight = None
cv_results = cross_validate(
LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True
)
display = RocCurveDisplay.from_cv_results(
cv_results,
X,
y,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
response_method=response_method,
pos_label=pos_label,
curve_kwargs=curve_kwargs,
)
for idx, (estimator, test_indices) in enumerate(
zip(cv_results["estimator"], cv_results["indices"]["test"])
):
y_true = _safe_indexing(y, test_indices)
y_pred = _get_response_values_binary(
estimator,
_safe_indexing(X, test_indices),
response_method=response_method,
pos_label=pos_label,
)[0]
sample_weight_fold = (
None
if sample_weight is None
else _safe_indexing(sample_weight, test_indices)
)
fpr, tpr, _ = roc_curve(
y_true,
y_pred,
sample_weight=sample_weight_fold,
drop_intermediate=drop_intermediate,
pos_label=pos_label,
)
assert_allclose(display.roc_auc[idx], auc(fpr, tpr))
assert_allclose(display.fpr[idx], fpr)
assert_allclose(display.tpr[idx], tpr)
assert display.name is None
import matplotlib as mpl
_check_figure_axes_and_labels(display, pos_label)
if with_sample_weight:
aggregate_expected_labels = ["AUC = 0.64 +/- 0.04", "_child1", "_child2"]
else:
aggregate_expected_labels = ["AUC = 0.61 +/- 0.05", "_child1", "_child2"]
for idx, line in enumerate(display.line_):
assert isinstance(line, mpl.lines.Line2D)
# Default alpha for `from_cv_results`
        assert line.get_alpha() == 0.5
if isinstance(curve_kwargs, list):
# Each individual curve labelled
assert line.get_label() == f"AUC = {display.roc_auc[idx]:.2f}"
else:
# Single aggregate label
assert line.get_label() == aggregate_expected_labels[idx]
@pytest.mark.parametrize("roc_auc", [[1.0, 1.0, 1.0], None])
@pytest.mark.parametrize(
"curve_kwargs",
[None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]],
)
@pytest.mark.parametrize("name", [None, "single", ["one", "two", "three"]])
def test_roc_curve_plot_legend_label(pyplot, data_binary, name, curve_kwargs, roc_auc):
"""Check legend label correct with all `curve_kwargs`, `name` combinations."""
fpr = [np.array([0, 0.5, 1]), np.array([0, 0.5, 1]), np.array([0, 0.5, 1])]
tpr = [np.array([0, 0.5, 1]), np.array([0, 0.5, 1]), np.array([0, 0.5, 1])]
if not isinstance(curve_kwargs, list) and isinstance(name, list):
with pytest.raises(ValueError, match="To avoid labeling individual curves"):
RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot(
name=name, curve_kwargs=curve_kwargs
)
else:
display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot(
name=name, curve_kwargs=curve_kwargs
)
legend = display.ax_.get_legend()
if legend is None:
# No legend is created, exit test early
assert name is None
assert roc_auc is None
return
else:
legend_labels = [text.get_text() for text in legend.get_texts()]
if isinstance(curve_kwargs, list):
# Multiple labels in legend
assert len(legend_labels) == 3
for idx, label in enumerate(legend_labels):
if name is None:
expected_label = "AUC = 1.00" if roc_auc else None
assert label == expected_label
elif isinstance(name, str):
expected_label = "single (AUC = 1.00)" if roc_auc else "single"
assert label == expected_label
else:
# `name` is a list of different strings
expected_label = (
f"{name[idx]} (AUC = 1.00)" if roc_auc else f"{name[idx]}"
)
assert label == expected_label
else:
# Single label in legend
assert len(legend_labels) == 1
if name is None:
expected_label = "AUC = 1.00 +/- 0.00" if roc_auc else None
assert legend_labels[0] == expected_label
else:
# name is single string
expected_label = "single (AUC = 1.00 +/- 0.00)" if roc_auc else "single"
assert legend_labels[0] == expected_label
@pytest.mark.parametrize(
"curve_kwargs",
[None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]],
)
@pytest.mark.parametrize("name", [None, "single", ["one", "two", "three"]])
def test_roc_curve_from_cv_results_legend_label(
pyplot, data_binary, name, curve_kwargs
):
"""Check legend label correct with all `curve_kwargs`, `name` combinations."""
X, y = data_binary
n_cv = 3
cv_results = cross_validate(
LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True
)
if not isinstance(curve_kwargs, list) and isinstance(name, list):
with pytest.raises(ValueError, match="To avoid labeling individual curves"):
RocCurveDisplay.from_cv_results(
cv_results, X, y, name=name, curve_kwargs=curve_kwargs
)
else:
display = RocCurveDisplay.from_cv_results(
cv_results, X, y, name=name, curve_kwargs=curve_kwargs
)
legend = display.ax_.get_legend()
legend_labels = [text.get_text() for text in legend.get_texts()]
if isinstance(curve_kwargs, list):
# Multiple labels in legend
assert len(legend_labels) == 3
auc = ["0.62", "0.66", "0.55"]
for idx, label in enumerate(legend_labels):
if name is None:
assert label == f"AUC = {auc[idx]}"
elif isinstance(name, str):
assert label == f"single (AUC = {auc[idx]})"
else:
# `name` is a list of different strings
assert label == f"{name[idx]} (AUC = {auc[idx]})"
else:
# Single label in legend
assert len(legend_labels) == 1
if name is None:
assert legend_labels[0] == "AUC = 0.61 +/- 0.05"
else:
# name is single string
assert legend_labels[0] == "single (AUC = 0.61 +/- 0.05)"
@pytest.mark.parametrize(
"curve_kwargs",
[None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]],
)
def test_roc_curve_from_cv_results_curve_kwargs(pyplot, data_binary, curve_kwargs):
"""Check line kwargs passed correctly in `from_cv_results`."""
X, y = data_binary
cv_results = cross_validate(
LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True
)
display = RocCurveDisplay.from_cv_results(
cv_results, X, y, curve_kwargs=curve_kwargs
)
for idx, line in enumerate(display.line_):
color = line.get_color()
if curve_kwargs is None:
# Default color
assert color == "blue"
elif isinstance(curve_kwargs, Mapping):
# All curves "red"
assert color == "red"
else:
assert color == curve_kwargs[idx]["c"]
def _check_chance_level(plot_chance_level, chance_level_kw, display):
"""Check chance level line and line styles correct."""
import matplotlib as mpl
if plot_chance_level:
assert isinstance(display.chance_level_, mpl.lines.Line2D)
assert tuple(display.chance_level_.get_xdata()) == (0, 1)
assert tuple(display.chance_level_.get_ydata()) == (0, 1)
else:
assert display.chance_level_ is None
# Checking for chance level line styles
if plot_chance_level and chance_level_kw is None:
assert display.chance_level_.get_color() == "k"
assert display.chance_level_.get_linestyle() == "--"
assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)"
elif plot_chance_level:
if "c" in chance_level_kw:
assert display.chance_level_.get_color() == chance_level_kw["c"]
else:
assert display.chance_level_.get_color() == chance_level_kw["color"]
if "lw" in chance_level_kw:
assert display.chance_level_.get_linewidth() == chance_level_kw["lw"]
else:
assert display.chance_level_.get_linewidth() == chance_level_kw["linewidth"]
if "ls" in chance_level_kw:
assert display.chance_level_.get_linestyle() == chance_level_kw["ls"]
else:
assert display.chance_level_.get_linestyle() == chance_level_kw["linestyle"]
@pytest.mark.parametrize("plot_chance_level", [True, False])
@pytest.mark.parametrize("label", [None, "Test Label"])
@pytest.mark.parametrize(
"chance_level_kw",
[
None,
{"linewidth": 1, "color": "red", "linestyle": "-", "label": "DummyEstimator"},
{"lw": 1, "c": "red", "ls": "-", "label": "DummyEstimator"},
{"lw": 1, "color": "blue", "ls": "-", "label": None},
],
)
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_roc_curve_chance_level_line(
pyplot,
data_binary,
plot_chance_level,
chance_level_kw,
label,
constructor_name,
):
"""Check chance level plotting behavior of `from_predictions`, `from_estimator`."""
X, y = data_binary
lr = LogisticRegression()
lr.fit(X, y)
    y_score = lr.predict_proba(X)
y_score = y_score if y_score.ndim == 1 else y_score[:, 1]
if constructor_name == "from_estimator":
display = RocCurveDisplay.from_estimator(
lr,
X,
y,
curve_kwargs={"alpha": 0.8, "label": label},
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
)
else:
display = RocCurveDisplay.from_predictions(
y,
y_score,
curve_kwargs={"alpha": 0.8, "label": label},
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
)
import matplotlib as mpl
assert isinstance(display.line_, mpl.lines.Line2D)
assert display.line_.get_alpha() == 0.8
assert isinstance(display.ax_, mpl.axes.Axes)
assert isinstance(display.figure_, mpl.figure.Figure)
_check_chance_level(plot_chance_level, chance_level_kw, display)
# Checking for legend behaviour
if plot_chance_level and chance_level_kw is not None:
if label is not None or chance_level_kw.get("label") is not None:
legend = display.ax_.get_legend()
assert legend is not None # Legend should be present if any label is set
legend_labels = [text.get_text() for text in legend.get_texts()]
if label is not None:
assert label in legend_labels
if chance_level_kw.get("label") is not None:
assert chance_level_kw["label"] in legend_labels
else:
assert display.ax_.get_legend() is None
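# Sketch of the chance-level options exercised above: `plot_chance_level=True`
# draws the diagonal "Chance level (AUC = 0.5)" reference line and
# `chance_level_kw` styles it. Argument names are illustrative.
def _example_chance_level(fitted_clf, X_demo, y_demo):
    return RocCurveDisplay.from_estimator(
        fitted_clf,
        X_demo,
        y_demo,
        plot_chance_level=True,
        chance_level_kw={"color": "red", "linestyle": "-", "label": "Chance"},
    )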
@pytest.mark.parametrize("plot_chance_level", [True, False])
@pytest.mark.parametrize(
"chance_level_kw",
[
None,
{"linewidth": 1, "color": "red", "linestyle": "-", "label": "DummyEstimator"},
{"lw": 1, "c": "red", "ls": "-", "label": "DummyEstimator"},
{"lw": 1, "color": "blue", "ls": "-", "label": None},
],
)
@pytest.mark.parametrize("curve_kwargs", [None, {"alpha": 0.8}])
def test_roc_curve_chance_level_line_from_cv_results(
pyplot,
data_binary,
plot_chance_level,
chance_level_kw,
curve_kwargs,
):
"""Check chance level plotting behavior with `from_cv_results`."""
X, y = data_binary
n_cv = 3
cv_results = cross_validate(
LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True
)
display = RocCurveDisplay.from_cv_results(
cv_results,
X,
y,
plot_chance_level=plot_chance_level,
chance_level_kwargs=chance_level_kw,
curve_kwargs=curve_kwargs,
)
import matplotlib as mpl
assert all(isinstance(line, mpl.lines.Line2D) for line in display.line_)
# Ensure both curve line kwargs passed correctly as well
if curve_kwargs:
assert all(line.get_alpha() == 0.8 for line in display.line_)
assert isinstance(display.ax_, mpl.axes.Axes)
assert isinstance(display.figure_, mpl.figure.Figure)
_check_chance_level(plot_chance_level, chance_level_kw, display)
legend = display.ax_.get_legend()
# There is always a legend, to indicate each 'Fold' curve
assert legend is not None
legend_labels = [text.get_text() for text in legend.get_texts()]
if plot_chance_level and chance_level_kw is not None:
if chance_level_kw.get("label") is not None:
assert chance_level_kw["label"] in legend_labels
else:
assert len(legend_labels) == 1
@pytest.mark.parametrize(
"clf",
[
LogisticRegression(),
make_pipeline(StandardScaler(), LogisticRegression()),
make_pipeline(
make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
),
],
)
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructor_name):
"""Check the behaviour with complex pipeline."""
X, y = data_binary
clf = clone(clf)
if constructor_name == "from_estimator":
with pytest.raises(NotFittedError):
RocCurveDisplay.from_estimator(clf, X, y)
clf.fit(X, y)
if constructor_name == "from_estimator":
display = RocCurveDisplay.from_estimator(clf, X, y)
name = clf.__class__.__name__
else:
display = RocCurveDisplay.from_predictions(y, y)
name = "Classifier"
assert name in display.line_.get_label()
assert display.name == name
@pytest.mark.parametrize(
"roc_auc, name, curve_kwargs, expected_labels",
[
([0.9, 0.8], None, None, ["AUC = 0.85 +/- 0.05", "_child1"]),
([0.9, 0.8], "Est name", None, ["Est name (AUC = 0.85 +/- 0.05)", "_child1"]),
(
[0.8, 0.7],
["fold1", "fold2"],
[{"c": "blue"}, {"c": "red"}],
["fold1 (AUC = 0.80)", "fold2 (AUC = 0.70)"],
),
(None, ["fold1", "fold2"], [{"c": "blue"}, {"c": "red"}], ["fold1", "fold2"]),
],
)
def test_roc_curve_display_default_labels(
pyplot, roc_auc, name, curve_kwargs, expected_labels
):
"""Check the default labels used in the display."""
fpr = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])]
tpr = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])]
disp = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, name=name).plot(
curve_kwargs=curve_kwargs
)
for idx, expected_label in enumerate(expected_labels):
assert disp.line_[idx].get_label() == expected_label
def _check_auc(display, constructor_name):
roc_auc_limit = 0.95679
roc_auc_limit_multi = [0.97007, 0.985915, 0.980952]
if constructor_name == "from_cv_results":
for idx, roc_auc in enumerate(display.roc_auc):
assert roc_auc == pytest.approx(roc_auc_limit_multi[idx])
else:
assert display.roc_auc == pytest.approx(roc_auc_limit)
assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit)
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize(
"constructor_name", ["from_estimator", "from_predictions", "from_cv_results"]
)
def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name):
# check that we can provide the positive label and display the proper
# statistics
X, y = load_breast_cancer(return_X_y=True)
    # create a highly imbalanced dataset
idx_positive = np.flatnonzero(y == 1)
idx_negative = np.flatnonzero(y == 0)
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
X, y = X[idx_selected], y[idx_selected]
X, y = shuffle(X, y, random_state=42)
# only use 2 features to make the problem even harder
X = X[:, :2]
y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
X_train, X_test, y_train, y_test = train_test_split(
X,
y,
stratify=y,
random_state=0,
)
classifier = LogisticRegression()
classifier.fit(X_train, y_train)
cv_results = cross_validate(
LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True
)
# Sanity check to be sure the positive class is `classes_[0]`
# Class imbalance ensures a large difference in prediction values between classes,
# allowing us to catch errors when we switch `pos_label`
assert classifier.classes_.tolist() == ["cancer", "not cancer"]
y_score = getattr(classifier, response_method)(X_test)
# we select the corresponding probability columns or reverse the decision
# function otherwise
y_score_cancer = -1 * y_score if y_score.ndim == 1 else y_score[:, 0]
y_score_not_cancer = y_score if y_score.ndim == 1 else y_score[:, 1]
pos_label = "cancer"
y_score = y_score_cancer
if constructor_name == "from_estimator":
display = RocCurveDisplay.from_estimator(
classifier,
X_test,
y_test,
pos_label=pos_label,
response_method=response_method,
)
elif constructor_name == "from_predictions":
display = RocCurveDisplay.from_predictions(
y_test,
y_score,
pos_label=pos_label,
)
else:
display = RocCurveDisplay.from_cv_results(
cv_results,
X,
y,
response_method=response_method,
pos_label=pos_label,
)
_check_auc(display, constructor_name)
pos_label = "not cancer"
y_score = y_score_not_cancer
if constructor_name == "from_estimator":
display = RocCurveDisplay.from_estimator(
classifier,
X_test,
y_test,
response_method=response_method,
pos_label=pos_label,
)
elif constructor_name == "from_predictions":
display = RocCurveDisplay.from_predictions(
y_test,
y_score,
pos_label=pos_label,
)
else:
display = RocCurveDisplay.from_cv_results(
cv_results,
X,
y,
response_method=response_method,
pos_label=pos_label,
)
_check_auc(display, constructor_name)
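# Sketch of the `pos_label` handling checked above: with string labels the
# positive class is passed explicitly and the scores must refer to that class
# (here the "cancer" probability column, assuming `classes_[0] == "cancer"` as
# asserted in the test). Argument names are illustrative.
def _example_pos_label(fitted_clf, X_demo, y_demo_str):
    proba_cancer = fitted_clf.predict_proba(X_demo)[:, 0]
    return RocCurveDisplay.from_predictions(
        y_demo_str, proba_cancer, pos_label="cancer"
    )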
# TODO(1.9): remove
def test_y_score_and_y_pred_specified_error():
"""Check that an error is raised when both y_score and y_pred are specified."""
y_true = np.array([0, 1, 1, 0])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
y_pred = np.array([0.2, 0.3, 0.5, 0.1])
with pytest.raises(
ValueError, match="`y_pred` and `y_score` cannot be both specified"
):
RocCurveDisplay.from_predictions(y_true, y_score=y_score, y_pred=y_pred)
# TODO(1.9): remove
def test_y_pred_deprecation_warning(pyplot):
"""Check that a warning is raised when y_pred is specified."""
y_true = np.array([0, 1, 1, 0])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
with pytest.warns(FutureWarning, match="y_pred is deprecated in 1.7"):
display_y_pred = RocCurveDisplay.from_predictions(y_true, y_pred=y_score)
assert_allclose(display_y_pred.fpr, [0, 0.5, 0.5, 1])
assert_allclose(display_y_pred.tpr, [0, 0, 1, 1])
display_y_score = RocCurveDisplay.from_predictions(y_true, y_score)
assert_allclose(display_y_score.fpr, [0, 0.5, 0.5, 1])
assert_allclose(display_y_score.tpr, [0, 0, 1, 1])
@pytest.mark.parametrize("despine", [True, False])
@pytest.mark.parametrize(
"constructor_name", ["from_estimator", "from_predictions", "from_cv_results"]
)
def test_plot_roc_curve_despine(pyplot, data_binary, despine, constructor_name):
# Check that the despine keyword is working correctly
X, y = data_binary
    lr = LogisticRegression().fit(X, y)
cv_results = cross_validate(
LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True
)
y_pred = lr.decision_function(X)
    # safeguard for the if/else construction
assert constructor_name in ("from_estimator", "from_predictions", "from_cv_results")
if constructor_name == "from_estimator":
display = RocCurveDisplay.from_estimator(lr, X, y, despine=despine)
elif constructor_name == "from_predictions":
display = RocCurveDisplay.from_predictions(y, y_pred, despine=despine)
else:
display = RocCurveDisplay.from_cv_results(cv_results, X, y, despine=despine)
for s in ["top", "right"]:
assert display.ax_.spines[s].get_visible() is not despine
if despine:
for s in ["bottom", "left"]:
assert display.ax_.spines[s].get_bounds() == (0, 1)
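# Sketch of the `despine` option verified above: top and right spines are
# hidden and the remaining spines are bounded to [0, 1]. Argument names are
# illustrative.
def _example_despine(y_demo, y_score_demo):
    return RocCurveDisplay.from_predictions(y_demo, y_score_demo, despine=True)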