CoefficientsDisplay

class skore.CoefficientsDisplay(*, coefficients, report_type)

Display to inspect the coefficients of linear models.

Parameters:
coefficients : DataFrame | list[DataFrame]

The coefficients data to display. The columns are:

  • estimator

  • split

  • feature

  • label or output (classification vs. regression)

  • coefficients

report_type : {“estimator”, “cross-validation”, “comparison-estimator”, “comparison-cross-validation”}

Report type from which the display is created.

Attributes:
ax_ : ndarray of matplotlib Axes

Array containing the matplotlib Axes of the plot.

figure_ : matplotlib Figure

Figure containing the plot.

Examples

>>> from sklearn.datasets import load_iris
>>> from sklearn.linear_model import LogisticRegression
>>> from skore import EstimatorReport, train_test_split
>>> iris = load_iris(as_frame=True)
>>> X, y = iris.data, iris.target
>>> y = iris.target_names[y]
>>> split_data = train_test_split(
...     X=X, y=y, random_state=0, as_dict=True, shuffle=True
... )
>>> report = EstimatorReport(LogisticRegression(), **split_data)
>>> display = report.feature_importance.coefficients()
>>> display.frame()
              feature       label  coefficients
0           Intercept      setosa      9.2...
1   sepal length (cm)      setosa     -0.4...
2    sepal width (cm)      setosa      0.8...
3   petal length (cm)      setosa     -2.3...
4    petal width (cm)      setosa     -0.9...
5           Intercept  versicolor      1.7...
6   sepal length (cm)  versicolor      0.5...
7    sepal width (cm)  versicolor     -0.2...
8   petal length (cm)  versicolor     -0.2...
9    petal width (cm)  versicolor     -0.7...
10          Intercept   virginica    -11.0...
11  sepal length (cm)   virginica     -0.1...
12   sepal width (cm)   virginica     -0.5...
13  petal length (cm)   virginica      2.5...
14   petal width (cm)   virginica      1.7...
frame(*, include_intercept=True)

Get the coefficients in a dataframe format.

The returned dataframe omits constant columns and columns that contain only NaN values.

Parameters:
include_intercept : bool, default=True

Whether or not to include the intercept in the dataframe.

Returns:
DataFrame

Dataframe containing the coefficients of the linear model.

Examples

>>> from sklearn.datasets import load_iris
>>> from sklearn.linear_model import LogisticRegression
>>> from skore import EstimatorReport, train_test_split
>>> iris = load_iris(as_frame=True)
>>> X, y = iris.data, iris.target
>>> y = iris.target_names[y]
>>> split_data = train_test_split(
...     X=X, y=y, random_state=0, as_dict=True, shuffle=True
... )
>>> report = EstimatorReport(LogisticRegression(), **split_data)
>>> display = report.feature_importance.coefficients()
>>> display.frame()
              feature       label  coefficients
0           Intercept      setosa      9.2...
1   sepal length (cm)      setosa     -0.4...
2    sepal width (cm)      setosa      0.8...
3   petal length (cm)      setosa     -2.3...
4    petal width (cm)      setosa     -0.9...
5           Intercept  versicolor      1.7...
6   sepal length (cm)  versicolor      0.5...
7    sepal width (cm)  versicolor     -0.2...
8   petal length (cm)  versicolor     -0.2...
9    petal width (cm)  versicolor     -0.7...
10          Intercept   virginica    -11.0...
11  sepal length (cm)   virginica     -0.1...
12   sepal width (cm)   virginica     -0.5...
13  petal length (cm)   virginica      2.5...
14   petal width (cm)   virginica      1.7...
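
The filtering that frame() describes can be mimicked with plain pandas. The sketch below is not skore's implementation: the helper name tidy_frame and the toy values are invented for illustration, and it assumes that the intercept rows are identified by the value "Intercept" in the feature column, as in the output above. It shows how a constant estimator column and a NaN-only split column would drop out, leaving only feature, label and coefficients.

```python
import numpy as np
import pandas as pd

# Toy data invented for illustration; the column names follow the
# "coefficients" parameter description of CoefficientsDisplay.
raw = pd.DataFrame(
    {
        "estimator": ["LogisticRegression"] * 6,  # constant column
        "split": [np.nan] * 6,                    # NaN-only column
        "feature": ["Intercept", "x0", "x1"] * 2,
        "label": ["a"] * 3 + ["b"] * 3,
        "coefficients": [1.0, 0.5, -0.5, -1.0, 0.25, 0.75],
    }
)


def tidy_frame(df, include_intercept=True):
    """Hypothetical helper that imitates the documented filtering."""
    if not include_intercept:
        # Assumption: intercept rows carry the feature name "Intercept".
        df = df[df["feature"] != "Intercept"]
    # Counting NaN as a value, a column with a single distinct value is
    # either constant or NaN-only; keep only the other columns.
    keep = df.nunique(dropna=False) > 1
    return df.loc[:, keep].reset_index(drop=True)


with_intercept = tidy_frame(raw)
without_intercept = tidy_frame(raw, include_intercept=False)
print(list(with_intercept.columns))
print(without_intercept)
```

With include_intercept=False, the two "Intercept" rows are removed before the column filter is applied, so the remaining rows only describe the input features.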
help()

Display available attributes and methods using rich.

plot(*, include_intercept=True, subplot_by='auto')

Plot the coefficients for the different features.

Parameters:
include_intercept : bool, default=True

Whether or not to include the intercept in the plot.

subplot_by : {“auto”, “estimator”, “label”, “output”} or None, default=“auto”

The column used to divide the coefficients into subplots. If “auto”, no subplotting is performed except in the following cases:

  • when comparing estimators in a multiclass classification or multi-output regression problem;

  • when comparing estimators for which the input features are different.

Examples

>>> from sklearn.datasets import load_iris
>>> from sklearn.linear_model import LogisticRegression
>>> from skore import EstimatorReport, train_test_split
>>> iris = load_iris(as_frame=True)
>>> X, y = iris.data, iris.target
>>> y = iris.target_names[y]
>>> split_data = train_test_split(
...     X=X, y=y, random_state=0, as_dict=True, shuffle=True
... )
>>> report = EstimatorReport(LogisticRegression(), **split_data)
>>> display = report.feature_importance.coefficients()
>>> display.plot()
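
To make the idea behind subplot_by more concrete, the sketch below draws one bar subplot per value of a chosen column using plain pandas and matplotlib. It is not how skore renders the display (skore relies on seaborn, as the set_style parameters indicate); the helper name plot_by_column and the toy coefficients are invented for illustration.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import pandas as pd

# Toy coefficients invented for illustration.
coefs = pd.DataFrame(
    {
        "feature": ["x0", "x1"] * 3,
        "label": ["a", "a", "b", "b", "c", "c"],
        "coefficients": [0.5, -1.0, 0.2, 0.3, -0.7, 0.7],
    }
)


def plot_by_column(frame, subplot_by):
    """Hypothetical helper: one horizontal bar plot per value of `subplot_by`."""
    groups = list(frame.groupby(subplot_by, sort=False))
    fig, axes = plt.subplots(
        ncols=len(groups),
        sharex=True,
        squeeze=False,  # always get a 2D array of Axes
        figsize=(4 * len(groups), 3),
    )
    for ax, (name, group) in zip(axes[0], groups):
        ax.barh(group["feature"], group["coefficients"])
        ax.set_title(f"{subplot_by} = {name}")
        ax.set_xlabel("coefficient")
    fig.tight_layout()
    return fig, axes[0]


fig, axes = plot_by_column(coefs, subplot_by="label")
```

Here the three classes each get their own Axes, which is the kind of layout one would expect from subplot_by="label" on a multiclass problem.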
set_style(*, policy='update', barplot_kwargs=None, boxplot_kwargs=None, stripplot_kwargs=None)

Set the style parameters for the display.

Parameters:
policy : Literal[“override”, “update”], default=“update”

Policy to use when setting the style parameters. If “override”, existing settings are set to the provided values. If “update”, existing settings are not changed; only settings that were previously unset are changed.

barplot_kwargs : dict, default=None

Keyword arguments to be passed to seaborn.barplot() for rendering the coefficients with an EstimatorReport or ComparisonReport of EstimatorReport.

boxplot_kwargs : dict, default=None

Keyword arguments to be passed to seaborn.boxplot() for rendering the coefficients with a CrossValidationReport or ComparisonReport of CrossValidationReport.

stripplot_kwargs : dict, default=None

Keyword arguments to be passed to seaborn.stripplot() for rendering the coefficients with a CrossValidationReport or ComparisonReport of CrossValidationReport.

Returns:
self : object

The instance with a modified style.

Raises:
ValueError

If a style parameter is unknown.

static style_plot(plot_func)

Apply consistent style to skore displays.

This decorator:

  1. Applies default style settings.

  2. Executes plot_func.

  3. Calls plt.tight_layout() to make sure the axes do not overlap.

  4. Restores the original style settings.

Parameters:
plot_func : callable

The plot function to be decorated.

Returns:
callable

The decorated plot function.
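
The four steps described for style_plot follow a common apply, execute, restore pattern. A minimal sketch of that pattern using matplotlib's rc_context is given below. It is not skore's implementation: the name style_plot_sketch and the contents of DEFAULT_STYLE are assumptions made for illustration.

```python
import functools

import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt

# Assumed default style; skore's actual defaults are not shown in this page.
DEFAULT_STYLE = {"lines.linewidth": 3.0, "axes.titlesize": 14}


def style_plot_sketch(plot_func):
    """Hypothetical decorator imitating the four documented steps."""

    @functools.wraps(plot_func)
    def wrapper(*args, **kwargs):
        # Steps 1 and 4: rc_context applies the settings on entry and
        # restores the previous rcParams on exit, even if plot_func raises.
        with plt.rc_context(DEFAULT_STYLE):
            result = plot_func(*args, **kwargs)  # step 2
            plt.tight_layout()                   # step 3
        return result

    return wrapper


@style_plot_sketch
def draw():
    fig, ax = plt.subplots()
    (line,) = ax.plot([0, 1], [0, 1])
    # The line picks up the width that is active while the function runs.
    return line.get_linewidth()


before = plt.rcParams["lines.linewidth"]
inside = draw()
after = plt.rcParams["lines.linewidth"]
print(before, inside, after)
```

Using a context manager for the save and restore steps keeps the global rcParams untouched for code that runs after the decorated function returns.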