ConfusionMatrixDisplay

class skore.ConfusionMatrixDisplay(*, confusion_matrix_predict, confusion_matrix_ovr, confusion_matrix_thresholded, report_type, ml_task, data_source, report_pos_label)

Display the confusion matrix.

Parameters:
confusion_matrix_predict : pd.DataFrame

Predict-based n x n confusion matrix in long format. It always contains “true_label”, “predicted_label”, “count”, “normalized_by_true”, “normalized_by_pred”, and “normalized_by_all”; it can also contain “split” and “estimator” when those dimensions are meaningful for the report.

confusion_matrix_ovr : pd.DataFrame or None

Predict-based one-vs-rest 2x2 confusion matrix in long format for multiclass classification. It has the same columns as confusion_matrix_predict, plus “label”. None for binary classification.

confusion_matrix_thresholded : pd.DataFrame or None

Per-class one-vs-rest thresholded 2x2 confusion matrix in long format. It has the same columns as confusion_matrix_predict, plus “threshold” and “label”. None when the estimator only supports predict.

report_type : {“comparison-cross-validation”, “comparison-estimator”, “cross-validation”, “estimator”}

The type of report.

ml_task : {“binary-classification”, “multiclass-classification”}

The machine learning task.

data_source : {“test”, “train”}

The data source to use.

report_pos_label : int, float, bool, str or None

The default positive label for display.

Attributes:
confusion_matrix_predict : pd.DataFrame

Predict-based confusion matrix data in long format.

confusion_matrix_ovr : pd.DataFrame or None

One-vs-rest confusion matrix data for multiclass classification.

confusion_matrix_thresholded : pd.DataFrame or None

Thresholded one-vs-rest confusion matrix data.

report_type : ReportType

The type of report.

ml_task : MLTask

The machine learning task.

data_source : DataSource

The data source used to compute the matrix.

report_pos_label : PositiveLabel

The default positive label for display.

labels : list

Available class labels.
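
The long format used by the confusion matrix attributes stores one row per (true label, predicted label) pair instead of an n x n grid. The snippet below is a minimal plain-Python sketch of that layout, not skore's implementation: the helper name long_format_confusion_matrix is hypothetical, rows are plain dicts rather than a pd.DataFrame, and returning 0.0 for an empty row or column is an assumption.

```python
from collections import Counter


def long_format_confusion_matrix(y_true, y_pred, labels):
    """Return one dict per (true_label, predicted_label) pair (hypothetical helper)."""
    counts = Counter(zip(y_true, y_pred))
    true_totals = Counter(y_true)   # row totals
    pred_totals = Counter(y_pred)   # column totals
    total = len(y_true)

    def ratio(numerator, denominator):
        # Assumption: an empty row or column normalizes to 0.0.
        return numerator / denominator if denominator else 0.0

    rows = []
    for true_label in labels:
        for predicted_label in labels:
            count = counts[(true_label, predicted_label)]
            rows.append(
                {
                    "true_label": true_label,
                    "predicted_label": predicted_label,
                    "count": count,
                    "normalized_by_true": ratio(count, true_totals[true_label]),
                    "normalized_by_pred": ratio(count, pred_totals[predicted_label]),
                    "normalized_by_all": ratio(count, total),
                }
            )
    return rows


y_true = ["cat", "cat", "dog", "dog", "dog"]
y_pred = ["cat", "dog", "dog", "dog", "cat"]
rows = long_format_confusion_matrix(y_true, y_pred, labels=["cat", "dog"])
```

Each of the four rows carries the raw count alongside its three normalized variants, so a consumer can pick a normalization without recomputing the matrix.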

See also

EstimatorReport.metrics.confusion_matrix

Create this display from a report.

RocCurveDisplay

Plot ROC curves for the same classifier.

PrecisionRecallCurveDisplay

Plot precision-recall curves.

Notes

For multiclass problems, thresholded views use a one-vs-rest (OvR) binary reformulation for each class label.
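
As an illustration of the one-vs-rest reformulation, the sketch below collapses multiclass predictions into a 2x2 count table for each label. It is a plain-Python sketch under stated assumptions: one_vs_rest_counts is a hypothetical helper, not a skore function, and it works on hard predictions rather than on thresholded scores.

```python
def one_vs_rest_counts(y_true, y_pred, positive_label):
    """Treat positive_label as the positive class and every other label as negative."""
    counts = {"tn": 0, "fp": 0, "fn": 0, "tp": 0}
    for true_label, predicted_label in zip(y_true, y_pred):
        actual_positive = true_label == positive_label
        predicted_positive = predicted_label == positive_label
        if actual_positive and predicted_positive:
            counts["tp"] += 1
        elif predicted_positive:
            counts["fp"] += 1
        elif actual_positive:
            counts["fn"] += 1
        else:
            counts["tn"] += 1
    return counts


y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 1, 1, 0]
# One binary problem per class label.
ovr = {label: one_vs_rest_counts(y_true, y_pred, label) for label in (0, 1, 2)}
```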

When threshold_value is a float, the display snaps to the closest threshold available in the stored thresholded data.
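
The snapping behavior can be pictured as a nearest-neighbor lookup over the stored thresholds. This is a minimal sketch, not skore's code: closest_threshold is a hypothetical helper, and resolving ties in favor of the first threshold in the list is an assumption.

```python
def closest_threshold(available_thresholds, requested):
    """Return the stored threshold nearest to the requested value (hypothetical helper)."""
    # min() keeps the first candidate on ties; the tie-breaking rule is an assumption.
    return min(available_thresholds, key=lambda t: abs(t - requested))


thresholds = [0.1, 0.35, 0.5, 0.8]
snapped_low = closest_threshold(thresholds, 0.4)    # nearest stored value is 0.35
snapped_high = closest_threshold(thresholds, 0.95)  # nearest stored value is 0.8
```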

For cross-validation reports, plotting with subplot_by other than "split" aggregates counts across folds (mean and standard deviation).
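
To make the fold aggregation concrete, the sketch below summarizes each confusion matrix cell by its mean and standard deviation over three folds. The counts are made-up illustration data, and the use of the sample standard deviation (statistics.stdev) is an assumption; the source does not say which estimator skore uses.

```python
from statistics import mean, stdev

# Made-up counts for each (true_label, predicted_label) cell over three folds.
cell_counts_per_fold = {
    ("neg", "neg"): [50, 48, 52],
    ("neg", "pos"): [5, 7, 3],
    ("pos", "neg"): [4, 6, 5],
    ("pos", "pos"): [41, 39, 40],
}

# Assumption: sample standard deviation; the library may use another estimator.
summary = {
    cell: {"mean": mean(counts), "std": stdev(counts)}
    for cell, counts in cell_counts_per_fold.items()
}
```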

Examples

>>> from sklearn.datasets import load_breast_cancer
>>> from sklearn.linear_model import LogisticRegression
>>> from skore import evaluate
>>> X, y = load_breast_cancer(return_X_y=True)
>>> classifier = LogisticRegression(max_iter=10_000)
>>> report = evaluate(classifier, X, y, splitter=0.2)
>>> display = report.metrics.confusion_matrix()
>>> display.plot()

frame(*, normalize=None, threshold_value=None, label=<DEFAULT>)

Return the confusion matrix as a long format dataframe.

When the inspected classifier has a predict_proba or decision_function method, the confusion matrix can be displayed at various decision thresholds. This is useful for understanding how the model’s predictions change as the decision threshold varies. In multiclass, this view is obtained by creating a binary problem for each label in a one-vs-rest fashion. Use threshold_value="all" to return all available thresholds without filtering.

Parameters:
normalize : {‘true’, ‘pred’, ‘all’} or None, default=None

Normalize the confusion matrix over the true labels (rows), the predicted labels (columns), or the whole population. If None, raw counts are returned in the “value” column.

threshold_value : float, “all”, or None, default=None

When None, returns the predict-based n x n confusion matrix. When “all”, returns the thresholded OvR data at all thresholds. When a float, returns the thresholded OvR data at the closest available threshold.

label : int, float, bool, str or None, default=report pos_label

The class to select. Use None to select all classes.

Returns:
frame : pandas.DataFrame

The confusion matrix as a dataframe with a “value” column and optional metadata columns such as “threshold”, “label”, “split”, and “estimator”.
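
The three normalize options can be sketched on a small matrix of counts. This is a plain-Python illustration rather than the library's code: normalize_matrix is a hypothetical helper, the matrix is a nested list instead of a long-format dataframe, and it assumes every row and column contains at least one sample.

```python
def normalize_matrix(matrix, normalize=None):
    """matrix[i][j] is the count of samples with true class i predicted as class j."""
    if normalize is None:
        return [list(row) for row in matrix]
    if normalize == "true":
        # Divide each cell by its row total (all samples of that true class).
        return [[value / sum(row) for value in row] for row in matrix]
    if normalize == "pred":
        # Divide each cell by its column total (all samples predicted as that class).
        column_totals = [sum(column) for column in zip(*matrix)]
        return [[v / t for v, t in zip(row, column_totals)] for row in matrix]
    if normalize == "all":
        # Divide each cell by the total number of samples.
        total = sum(sum(row) for row in matrix)
        return [[value / total for value in row] for row in matrix]
    raise ValueError(f"Unknown normalize option: {normalize!r}")


cm = [[8, 2], [1, 9]]
by_true = normalize_matrix(cm, "true")
by_pred = normalize_matrix(cm, "pred")
by_all = normalize_matrix(cm, "all")
```

With normalize="true" each row sums to 1, with "pred" each column sums to 1, and with "all" the whole matrix sums to 1.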

help()

Display help for this display using rich or HTML.

plot(*, normalize=None, threshold_value=None, subplot_by='auto', label=<DEFAULT>)

Plot the confusion matrix.

When the inspected classifier has a predict_proba or decision_function method, the confusion matrix can be displayed at various decision thresholds. This is useful for understanding how the model’s predictions change as the decision threshold varies. In multiclass, this view is obtained by creating a binary problem for each label in a one-vs-rest fashion.

Parameters:
normalize : {‘true’, ‘pred’, ‘all’} or None, default=None

Normalize the confusion matrix over the true labels (rows), the predicted labels (columns), or the whole population. If None, the confusion matrix is not normalized.

threshold_value : float or None, default=None

When None, plots the predict-based n x n confusion matrix. When a float, plots the thresholded 2x2 confusion matrix at the closest available threshold for label. In multiclass, this matrix is obtained by creating a binary problem for the label in a one-vs-rest fashion.

subplot_by : {“split”, “estimator”, “auto”} or None, default=“auto”

The variable used to split the plot into subplots. If None, a single confusion matrix is plotted without subplots. If “auto”, the variable is determined automatically from the report type.

label : int, float, bool, str or None, default=report pos_label

The class to consider as positive. In multiclass, the predict-based and thresholded views are shown in a one-vs-rest fashion for this label.

Returns:
matplotlib.figure.Figure

Figure containing the confusion matrix.
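
The thresholded 2x2 view described above can be sketched by turning scores into binary decisions at a given threshold. This is an illustration under stated assumptions: thresholded_counts is a hypothetical helper, the scores are made-up probabilities for the positive label, and treating score >= threshold as a positive prediction is an assumption about the comparison rule.

```python
def thresholded_counts(y_true, scores, positive_label, threshold):
    """Count tn/fp/fn/tp when predicting positive_label for score >= threshold."""
    counts = {"tn": 0, "fp": 0, "fn": 0, "tp": 0}
    for true_label, score in zip(y_true, scores):
        actual_positive = true_label == positive_label
        predicted_positive = score >= threshold  # assumption: inclusive comparison
        if actual_positive and predicted_positive:
            counts["tp"] += 1
        elif predicted_positive:
            counts["fp"] += 1
        elif actual_positive:
            counts["fn"] += 1
        else:
            counts["tn"] += 1
    return counts


y_true = [1, 0, 1, 1, 0]
scores = [0.9, 0.4, 0.6, 0.3, 0.7]  # made-up scores for label 1
at_half = thresholded_counts(y_true, scores, positive_label=1, threshold=0.5)
at_strict = thresholded_counts(y_true, scores, positive_label=1, threshold=0.8)
```

Raising the threshold from 0.5 to 0.8 removes the false positive but turns one true positive into a false negative, which is the trade-off this view is meant to expose.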

set_style(*, policy='update', heatmap_kwargs=None, facet_grid_kwargs=None)

Set the style parameters for the display.

Parameters:
policy : {“override”, “update”}, default=“update”

Policy to use when setting the style parameters. If “override”, existing settings are set to the provided values. If “update”, existing settings are not changed; only settings that were previously unset are changed.

heatmap_kwargs : dict, default=None

Additional keyword arguments to be passed to seaborn.heatmap().

facet_grid_kwargs : dict, default=None

Additional keyword arguments to be passed to seaborn.FacetGrid.

Returns:
None

Raises:
ValueError

If a style parameter is unknown.

static style_plot(plot_func)

Apply consistent style to skore displays.

This decorator:

1. Applies default style settings.
2. Runs plot_func under plt.ioff() so figures are not shown until returned.
3. Calls Figure.tight_layout() on the returned figure when applicable.
4. Restores the original style settings.

Parameters:
plot_func : callable

The plot function to be decorated.

Returns:
callable

The decorated plot function.
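
The apply-then-restore pattern behind such a decorator can be sketched without matplotlib. Everything here is a stand-in for illustration: STYLE, DEFAULT_STYLE, with_default_style, and fake_plot are hypothetical names, and the sketch omits the plt.ioff() and tight_layout() steps, so it should not be read as skore's implementation.

```python
import functools

# Hypothetical stand-ins for global plotting settings and skore's defaults.
STYLE = {"font.size": 10}
DEFAULT_STYLE = {"font.size": 12, "figure.dpi": 100}


def with_default_style(plot_func):
    """Apply DEFAULT_STYLE while plot_func runs, then restore the previous settings."""

    @functools.wraps(plot_func)
    def wrapper(*args, **kwargs):
        original = dict(STYLE)       # remember the caller's settings
        STYLE.update(DEFAULT_STYLE)  # apply the default style
        try:
            return plot_func(*args, **kwargs)
        finally:
            # Restore the settings even if plot_func raises.
            STYLE.clear()
            STYLE.update(original)

    return wrapper


@with_default_style
def fake_plot():
    # Returns the settings that are active while "plotting".
    return dict(STYLE)


style_seen_while_plotting = fake_plot()
```

Restoring in a finally clause keeps a failed plot from leaking style changes into later figures.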