CrossValidationReport.get_predictions

CrossValidationReport.get_predictions(*, data_source, response_method='predict', X=None, pos_label=<DEFAULT>)

Get the estimator’s predictions for each cross-validation split.

This method reloads the predictions from the cache if they were already computed in a previous call, which avoids recomputing them.
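The caching behavior can be pictured as memoization keyed on the call arguments. The sketch below is a hypothetical illustration, not skore’s implementation: the class name PredictionCache and the key layout (data source, response method, positive label) are assumptions chosen only to show the idea.

```python
from typing import Callable


class PredictionCache:
    """Hypothetical memoizing wrapper; not skore's actual cache."""

    def __init__(self, compute: Callable[..., list]):
        self._compute = compute  # function that really computes predictions
        self._cache: dict = {}  # assumed key: (data_source, response_method, pos_label)
        self.n_computations = 0  # counts cache misses, for illustration only

    def get(self, data_source: str, response_method: str, pos_label=None) -> list:
        key = (data_source, response_method, pos_label)
        if key not in self._cache:
            # Cache miss: compute once and store the result.
            self.n_computations += 1
            self._cache[key] = self._compute(data_source, response_method, pos_label)
        # Cache hit (or freshly stored value): return the stored predictions.
        return self._cache[key]


cache = PredictionCache(lambda ds, rm, pl: [f"{rm} on {ds}"])
cache.get("test", "predict")
cache.get("test", "predict")  # same arguments, served from the cache
print(cache.n_computations)  # 1
```

Any argument that changes the result, such as a different data source, must be part of the key, otherwise a stale entry would be returned.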

Parameters:
data_source : {“test”, “train”, “X_y”}, default=“test”

The data source to use.

  • “test” : use the test set of each cross-validation split.

  • “train” : use the train set of each cross-validation split.

  • “X_y” : use the input features passed through the X parameter.
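The choice of data source amounts to picking which feature matrix the response method is applied to for a given split. The helper below is a hypothetical sketch: the name select_features, its arguments, and the check that X is given for “X_y” are assumptions; only the ValueError for an invalid data source comes from this page.

```python
import numpy as np


def select_features(data_source, X_train, X_test, X=None):
    """Hypothetical helper: features of one split the response method is applied to."""
    if data_source == "test":
        return X_test  # held-out part of the current split
    if data_source == "train":
        return X_train  # training part of the current split
    if data_source == "X_y":
        # Assumption: X is required for "X_y"; the real message may differ.
        if X is None:
            raise ValueError('X must be provided when data_source="X_y".')
        return X
    # Documented behavior: an invalid data source raises ValueError.
    raise ValueError(
        f"Invalid data_source {data_source!r}; expected 'test', 'train' or 'X_y'."
    )


X_train, X_test, X_new = np.zeros((3, 2)), np.ones((2, 2)), np.full((4, 2), 5.0)
print(select_features("test", X_train, X_test).shape)  # (2, 2)
print(select_features("X_y", X_train, X_test, X=X_new).shape)  # (4, 2)
```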

response_method : {“predict”, “predict_proba”, “decision_function”}, default=“predict”

The estimator method used to compute the predictions.
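Each response method name corresponds to a method of the fitted scikit-learn estimator. The sketch below shows the raw shapes these methods return for a binary LogisticRegression fitted outside of skore; the report may post-process them (for binary problems, the pos_label handling reduces predict_proba to a single column), so this only illustrates the underlying scikit-learn outputs.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(random_state=42)  # 100 samples, binary target
estimator = LogisticRegression().fit(X, y)

# Look up each response method by name on the fitted estimator.
shapes = {
    method: getattr(estimator, method)(X).shape
    for method in ("predict", "predict_proba", "decision_function")
}
print(shapes)
# {'predict': (100,), 'predict_proba': (100, 2), 'decision_function': (100,)}
```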

X : array-like of shape (n_samples, n_features), optional

When data_source is “X_y”, the input features on which to compute the response method.

pos_label : int, float, bool, str or None, default=_DEFAULT

The label to consider as the positive class when computing predictions in binary classification cases. By default, the positive class is the one provided when creating the report. If None, estimator_.classes_[1] is used as the positive label.

When pos_label is equal to estimator_.classes_[0], the predictions are equivalent to estimator_.predict_proba(X)[:, 0] for response_method="predict_proba" and to -estimator_.decision_function(X) for response_method="decision_function".
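The equivalences above can be checked directly with scikit-learn. The helper binary_response below is hypothetical and only mirrors the documented rules for a binary problem: keep the probability column of the positive class, negate decision_function scores when the positive class is classes_[0], and fall back to classes_[1] when pos_label is None. It is not skore’s implementation.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(random_state=42)  # binary target with classes 0 and 1
estimator = LogisticRegression().fit(X, y)
neg, pos = estimator.classes_  # classes_ is sorted, so pos is classes_[1]
proba = estimator.predict_proba(X)
scores = estimator.decision_function(X)  # positive values favor classes_[1]


def binary_response(response_method, pos_label):
    """Hypothetical helper mirroring the documented pos_label behavior."""
    if pos_label is None:
        pos_label = estimator.classes_[1]  # documented fallback for None
    if response_method == "predict_proba":
        # Keep only the probability column of the positive class.
        col = np.flatnonzero(estimator.classes_ == pos_label)[0]
        return proba[:, col]
    if response_method == "decision_function":
        # Negate the scores when the positive class is classes_[0].
        return -scores if pos_label == estimator.classes_[0] else scores
    raise ValueError(f"Unsupported response_method: {response_method!r}")


print(np.allclose(binary_response("predict_proba", neg), 1 - proba[:, 1]))  # True
print(np.allclose(binary_response("decision_function", neg), -scores))  # True
```

Because the two class probabilities sum to one, choosing classes_[0] as the positive label yields the complement of the classes_[1] probabilities.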

Returns:
list of np.ndarray of shape (n_samples,) or (n_samples, n_classes)

The predictions for each cross-validation split.
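The returned list holds one array per split, each computed by that split’s fitted estimator. The plain scikit-learn loop below is a rough sketch of this structure for data_source="test" and is not how skore computes it. KFold is an illustrative choice: the splitter skore builds from an integer cv_splitter may differ (scikit-learn’s own check_cv, for instance, uses StratifiedKFold for classifiers).

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

X, y = make_classification(random_state=42)  # 100 samples

# One prediction array per split, computed on that split's held-out samples.
predictions = []
for train_idx, test_idx in KFold(n_splits=2).split(X):
    estimator = LogisticRegression().fit(X[train_idx], y[train_idx])
    predictions.append(estimator.predict(X[test_idx]))

print([p.shape for p in predictions])  # [(50,), (50,)]
```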

Raises:
ValueError

If the data source is invalid.

Examples

>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> X, y = make_classification(random_state=42)
>>> estimator = LogisticRegression()
>>> from skore import CrossValidationReport
>>> report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=2)
>>> predictions = report.get_predictions(data_source="test")
>>> print([split_predictions.shape for split_predictions in predictions])
[(50,), (50,)]