ConfusionMatrixDisplay
- class skore.ConfusionMatrixDisplay(*, confusion_matrix_predict, confusion_matrix_ovr, confusion_matrix_thresholded, report_type, ml_task, data_source, report_pos_label)
Display the confusion matrix.
- Parameters:
- confusion_matrix_predict : pd.DataFrame
Predict-based n x n confusion matrix in long format. It always contains "true_label", "predicted_label", "count", "normalized_by_true", "normalized_by_pred", and "normalized_by_all"; it can also contain "split" and "estimator" when those dimensions are meaningful for the report.
- confusion_matrix_ovr : pd.DataFrame or None
Predict-based one-vs-rest 2x2 confusion matrix in long format for multiclass classification. It has the same columns as confusion_matrix_predict, plus "label". None for binary classification.
- confusion_matrix_thresholded : pd.DataFrame or None
Per-class one-vs-rest thresholded 2x2 confusion matrix in long format. It has the same columns as confusion_matrix_predict, plus "threshold" and "label". None when the estimator only supports predict.
- report_type : {"comparison-cross-validation", "comparison-estimator", "cross-validation", "estimator"}
The type of report.
- ml_task : {"binary-classification", "multiclass-classification"}
The machine learning task.
- data_source : {"test", "train"}
The data source to use.
- report_pos_label : int, float, bool, str or None
The default positive label for display.
- Attributes:
- confusion_matrix_predict : pd.DataFrame
Predict-based confusion matrix data in long format.
- confusion_matrix_ovr : pd.DataFrame or None
One-vs-rest confusion matrix data for multiclass classification.
- confusion_matrix_thresholded : pd.DataFrame or None
Thresholded one-vs-rest confusion matrix data.
- report_type : ReportType
The type of report.
- ml_task : MLTask
The machine learning task.
- data_source : DataSource
The data source used to compute the matrix.
- report_pos_label : PositiveLabel
The default positive label for display.
- labels : list
Available class labels.
See also
EstimatorReport.metrics.confusion_matrix : Create this display from a report.
RocCurveDisplay : Plot ROC curves for the same classifier.
PrecisionRecallCurveDisplay : Plot precision-recall curves.
Notes
For multiclass problems, thresholded views use a one-vs-rest (OvR) binary reformulation for each class label.
When threshold_value is a float, the display snaps to the closest threshold available in the stored thresholded data.
For cross-validation reports, plotting with subplot_by other than "split" aggregates counts across folds (mean and standard deviation).
Examples
>>> from sklearn.datasets import load_breast_cancer
>>> from sklearn.linear_model import LogisticRegression
>>> from skore import evaluate
>>> X, y = load_breast_cancer(return_X_y=True)
>>> classifier = LogisticRegression(max_iter=10_000)
>>> report = evaluate(classifier, X, y, splitter=0.2)
>>> display = report.metrics.confusion_matrix()
>>> display.plot()
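The snapping behaviour mentioned in the Notes can be pictured with a small standalone sketch. It is not skore's implementation: `closest_threshold` is a hypothetical helper, the threshold values are made up, and how skore breaks ties between two equally distant thresholds is not stated in this page.

```python
def closest_threshold(available, requested):
    """Return the stored threshold nearest to the requested value.

    Hypothetical helper for illustration only. When two thresholds are
    equally close, the one that appears first in `available` is returned.
    """
    return min(available, key=lambda t: abs(t - requested))


# Made-up thresholds standing in for those stored in the thresholded data.
stored_thresholds = [0.1, 0.35, 0.5, 0.8]

# A request for 0.45 lands on the nearest stored value.
snapped = closest_threshold(stored_thresholds, 0.45)
print(snapped)
```

A requested value outside the stored range would simply map to the nearest end of that range under this sketch.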
- frame(*, normalize=None, threshold_value=None, label=<DEFAULT>)
Return the confusion matrix as a long-format dataframe.
When the inspected classifier has a predict_proba or decision_function method, the confusion matrix can be displayed at various decision thresholds. This is useful for understanding how the model's predictions change as the decision threshold varies. In multiclass, this view is obtained by creating a binary problem for each label in a one-vs-rest fashion. Use threshold_value="all" to return all available thresholds without filtering.
- Parameters:
- normalize : {"true", "pred", "all"}, default=None
Normalizes the confusion matrix over the true conditions (rows), the predicted conditions (columns), or the whole population. If None, raw counts are returned as the "value" column.
- threshold_value : float, "all", or None, default=None
When None, returns the predict-based n x n confusion matrix. When "all", returns the thresholded OvR data at all thresholds. When a float, returns the thresholded OvR data at the closest available threshold.
- label : int, float, bool, str or None, default=report pos_label
The class to select. Use None to select all classes.
- Returns:
- frame : pandas.DataFrame
The confusion matrix as a dataframe with a "value" column and optional metadata columns such as "threshold", "label", "split", and "estimator".
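To make the three normalize options concrete, here is a minimal pure-Python sketch of a long-format confusion matrix. It is an illustration under stated assumptions, not skore's code: `long_confusion_matrix` is a hypothetical function, rows are plain dicts rather than a pandas DataFrame, and returning 0.0 for an empty row or column is a choice made here.

```python
from collections import Counter


def long_confusion_matrix(y_true, y_pred, normalize=None):
    """Build a long-format confusion matrix as a list of row dicts."""
    if normalize not in (None, "true", "pred", "all"):
        raise ValueError(f"Unknown normalize option: {normalize!r}")

    labels = sorted(set(y_true) | set(y_pred))
    counts = Counter(zip(y_true, y_pred))
    true_totals = Counter(y_true)  # row sums
    pred_totals = Counter(y_pred)  # column sums

    rows = []
    for true_label in labels:
        for predicted_label in labels:
            count = counts[(true_label, predicted_label)]
            if normalize is None:
                value = count
            else:
                denominator = {
                    "true": true_totals[true_label],
                    "pred": pred_totals[predicted_label],
                    "all": len(y_true),
                }[normalize]
                # Assumption: an empty row or column yields 0.0, not NaN.
                value = count / denominator if denominator else 0.0
            rows.append(
                {
                    "true_label": true_label,
                    "predicted_label": predicted_label,
                    "value": value,
                }
            )
    return rows


y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]
raw = long_confusion_matrix(y_true, y_pred)
by_true = long_confusion_matrix(y_true, y_pred, normalize="true")
for row in by_true:
    print(row)
```

With normalize="true" each group of rows sharing a true label sums to 1, with "pred" each group sharing a predicted label does, and with "all" the whole table does.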
- plot(*, normalize=None, threshold_value=None, subplot_by='auto', label=<DEFAULT>)
Plot the confusion matrix.
When the inspected classifier has a predict_proba or decision_function method, the confusion matrix can be displayed at various decision thresholds. This is useful for understanding how the model's predictions change as the decision threshold varies. In multiclass, this view is obtained by creating a binary problem for each label in a one-vs-rest fashion.
- Parameters:
- normalize : {"true", "pred", "all"}, default=None
Normalizes the confusion matrix over the true conditions (rows), the predicted conditions (columns), or the whole population. If None, the confusion matrix is not normalized.
- threshold_value : float or None, default=None
When None, plots the predict-based n x n confusion matrix. When a float, plots the thresholded 2x2 confusion matrix at the closest available threshold for label. This is obtained in multiclass by creating a binary problem for the label in a one-vs-rest fashion.
- subplot_by : {"split", "estimator", "auto"} or None, default="auto"
The variable to use for subplotting. If None, the confusion matrix is not split into subplots. If "auto", the variable is determined automatically from the report type.
- label : int, float, bool, str or None, default=report pos_label
The class to consider as positive. In multiclass, the predict-based and thresholded views are shown in a one-vs-rest fashion for this label.
- Returns:
- matplotlib.figure.Figure
Figure containing the confusion matrix.
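The one-vs-rest thresholded 2x2 view can be sketched in plain Python as follows. This is an illustration, not skore's code: `ovr_confusion_at_threshold` is a hypothetical helper, the scores are made up, and treating `score >= threshold` as a positive prediction is an assumption about the convention.

```python
def ovr_confusion_at_threshold(y_true, scores, label, threshold):
    """Return [[tn, fp], [fn, tp]] for `label` against all other classes.

    `scores` holds one score per sample for `label`, for example a
    predicted probability. Assumption: score >= threshold means the
    sample is predicted as `label`.
    """
    tn = fp = fn = tp = 0
    for truth, score in zip(y_true, scores):
        is_positive = truth == label
        predicted_positive = score >= threshold
        if is_positive and predicted_positive:
            tp += 1
        elif is_positive:
            fn += 1
        elif predicted_positive:
            fp += 1
        else:
            tn += 1
    return [[tn, fp], [fn, tp]]


# Three-class toy data; scores are made-up scores for class "a".
y_true = ["a", "b", "c", "a", "b"]
scores_for_a = [0.9, 0.2, 0.4, 0.3, 0.6]

matrix_at_half = ovr_confusion_at_threshold(y_true, scores_for_a, "a", 0.5)
matrix_at_quarter = ovr_confusion_at_threshold(y_true, scores_for_a, "a", 0.25)
print(matrix_at_half)
print(matrix_at_quarter)
```

Lowering the threshold moves samples from the "predicted negative" column to the "predicted positive" column, which is the effect the thresholded view is meant to expose.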
- set_style(*, policy='update', heatmap_kwargs=None, facet_grid_kwargs=None)
Set the style parameters for the display.
- Parameters:
- policy : Literal["override", "update"], default="update"
Policy to use when setting the style parameters. If "override", existing settings are set to the provided values. If "update", existing settings are not changed; only settings that were previously unset are changed.
- heatmap_kwargs : dict, default=None
Additional keyword arguments to be passed to seaborn.heatmap().
- facet_grid_kwargs : dict, default=None
Additional keyword arguments to be passed to seaborn.FacetGrid.
- Returns:
- None
- Raises:
- ValueError
If a style parameter is unknown.
- static style_plot(plot_func)
Apply consistent style to skore displays.
This decorator:
1. Applies default style settings
2. Runs plot_func under plt.ioff() so figures are not shown until returned
3. Calls Figure.tight_layout() on the returned figure when applicable
4. Restores the original style settings
- Parameters:
- plot_func : callable
The plot function to be decorated.
- Returns:
- callable
The decorated plot function.
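The apply-then-restore pattern listed for this decorator can be sketched without matplotlib. Everything here is hypothetical: `GLOBAL_STYLE` is a plain dict standing in for a global style registry, `DEFAULT_DISPLAY_STYLE` and `style_plot_sketch` are illustrative names, and the plt.ioff() and tight_layout() steps are left out.

```python
import functools

# Stand-in for a global style registry such as matplotlib's rcParams.
GLOBAL_STYLE = {"font.size": 10}

# Hypothetical defaults that the decorator applies while plotting.
DEFAULT_DISPLAY_STYLE = {"font.size": 12, "figure.dpi": 100}


def style_plot_sketch(plot_func):
    """Apply default style settings around plot_func, then restore them."""

    @functools.wraps(plot_func)
    def wrapper(*args, **kwargs):
        original = dict(GLOBAL_STYLE)  # snapshot the current settings
        GLOBAL_STYLE.update(DEFAULT_DISPLAY_STYLE)
        try:
            return plot_func(*args, **kwargs)
        finally:
            # Restore the snapshot even if plot_func raises.
            GLOBAL_STYLE.clear()
            GLOBAL_STYLE.update(original)

    return wrapper


@style_plot_sketch
def fake_plot():
    """Return the style that is active while 'plotting'."""
    return dict(GLOBAL_STYLE)


seen_during_plot = fake_plot()
print(seen_during_plot)
print(GLOBAL_STYLE)
```

Using try/finally keeps the restore step unconditional, so a failing plot function does not leak the temporary style into later plots.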
Gallery examples
EstimatorReport: Get insights from any scikit-learn estimator