ImpurityDecreaseDisplay

class skore.ImpurityDecreaseDisplay(*, importances, report_type)

Display to inspect the Mean Decrease in Impurity (MDI) of tree-based models.
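MDI credits each feature with the total impurity reduction produced by the splits made on that feature. Each split's reduction is weighted by the fraction of samples that reach it. The sketch below recomputes this quantity for a single scikit-learn decision tree from its fitted `tree_` arrays and compares it with `feature_importances_`. It illustrates the quantity only and is not skore's implementation. The helper `mdi_from_tree` is hypothetical.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)


def mdi_from_tree(tree, n_features):
    """Hypothetical helper: MDI of one fitted sklearn tree, normalized to sum to 1."""
    importances = np.zeros(n_features)
    weights = tree.weighted_n_node_samples
    for node in range(tree.node_count):
        left, right = tree.children_left[node], tree.children_right[node]
        if left == -1:  # leaf node: no split, hence no impurity decrease
            continue
        # Weighted impurity of the parent minus the weighted impurity of its children
        importances[tree.feature[node]] += (
            weights[node] * tree.impurity[node]
            - weights[left] * tree.impurity[left]
            - weights[right] * tree.impurity[right]
        )
    importances /= weights[0]  # express relative to the root's sample weight
    return importances / importances.sum()


manual = mdi_from_tree(clf.tree_, X.shape[1])
print(np.round(manual, 3))
print(np.allclose(manual, clf.feature_importances_))  # True
```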

Parameters:
importances : DataFrame

The importances data to display. The columns are:

  • estimator

  • feature

  • importances

report_type{ā€œestimatorā€, ā€œcross-validationā€, ā€œcomparison-estimatorā€, ā€œcomparison-cross-validationā€}

Report type from which the display is created.
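For an estimator report, you can picture the `importances` frame in long format, with one row per feature. The sketch below builds such a frame by hand from scikit-learn's `feature_importances_`, which holds the impurity-based (MDI) importances for tree ensembles. It only illustrates the documented column layout and is not how skore builds the frame internally. Using the class name as the `estimator` value is an assumption.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris(as_frame=True)
X, y = iris.data, iris.target
forest = RandomForestClassifier(random_state=0).fit(X, y)

# One row per (estimator, feature) pair; the scalar "estimator" value is broadcast
importances = pd.DataFrame(
    {
        "estimator": type(forest).__name__,
        "feature": X.columns,
        "importances": forest.feature_importances_,
    }
)
print(importances)
```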

Attributes:
ax_ : matplotlib Axes

Matplotlib Axes with the plot.

facet_ : seaborn FacetGrid

FacetGrid containing the plot.

figure_ : matplotlib Figure

Figure containing the plot.

Examples

>>> from sklearn.datasets import load_iris
>>> from sklearn.ensemble import RandomForestClassifier
>>> from skore import EstimatorReport, train_test_split
>>> iris = load_iris(as_frame=True)
>>> X, y = iris.data, iris.target
>>> y = iris.target_names[y]
>>> split_data = train_test_split(
...     X=X, y=y, random_state=0, as_dict=True, shuffle=True
... )
>>> report = EstimatorReport(
...     RandomForestClassifier(random_state=0), **split_data
... )
>>> display = report.inspection.impurity_decrease()
>>> display.frame()
             feature  importances
0  sepal length (cm)     0.1...
1   sepal width (cm)     0.0...
2  petal length (cm)     0.4...
3   petal width (cm)     0.3...
frame()

Get the mean decrease in impurity in a dataframe format.

The returned dataframe omits constant columns and columns that contain only NaN values.

Returns:
DataFrame

Dataframe containing the mean decrease in impurity of the tree-based model.
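The column-filtering rule can be sketched with pandas. This rule is why the example output lists only `feature` and `importances`: with a single estimator, the `estimator` column is constant. The helper `drop_uninformative_columns` is hypothetical. Treating a column that mixes one value with NaNs as informative is an assumption, not documented skore behaviour.

```python
import numpy as np
import pandas as pd


def drop_uninformative_columns(df):
    """Hypothetical helper: drop constant columns and all-NaN columns."""
    # With dropna=False, NaN counts as a value, so both a constant column
    # and an all-NaN column have exactly one unique value.
    keep = [col for col in df.columns if df[col].nunique(dropna=False) > 1]
    return df[keep]


importances = pd.DataFrame(
    {
        "estimator": ["RandomForestClassifier"] * 3,  # constant: dropped
        "feature": ["a", "b", "c"],
        "importances": [0.5, 0.3, 0.2],
        "extra": [np.nan] * 3,  # all NaN: dropped
    }
)
print(drop_uninformative_columns(importances).columns.tolist())
# ['feature', 'importances']
```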

Examples

>>> from sklearn.datasets import load_iris
>>> from sklearn.ensemble import RandomForestClassifier
>>> from skore import EstimatorReport, train_test_split
>>> iris = load_iris(as_frame=True)
>>> X, y = iris.data, iris.target
>>> y = iris.target_names[y]
>>> split_data = train_test_split(
...     X=X, y=y, random_state=0, as_dict=True, shuffle=True
... )
>>> report = EstimatorReport(
...     RandomForestClassifier(random_state=0), **split_data
... )
>>> display = report.inspection.impurity_decrease()
>>> display.frame()
             feature  importances
0  sepal length (cm)     0.1...
1   sepal width (cm)     0.0...
2  petal length (cm)     0.4...
3   petal width (cm)     0.3...
help()

Display available attributes and methods using rich.

plot()

Plot the mean decrease in impurity for the different features.

Examples

>>> from sklearn.datasets import load_iris
>>> from sklearn.ensemble import RandomForestClassifier
>>> from skore import EstimatorReport, train_test_split
>>> iris = load_iris(as_frame=True)
>>> X, y = iris.data, iris.target
>>> y = iris.target_names[y]
>>> split_data = train_test_split(
...     X=X, y=y, random_state=0, as_dict=True, shuffle=True
... )
>>> report = EstimatorReport(
...     RandomForestClassifier(random_state=0), **split_data
... )
>>> display = report.inspection.impurity_decrease()
>>> display.plot()
set_style(*, policy='update', barplot_kwargs=None, stripplot_kwargs=None, boxplot_kwargs=None)

Set the style parameters for the display.

Parameters:
policy{ā€œoverrideā€, ā€œupdateā€}, default=ā€updateā€

Policy to use when setting the style parameters. If ā€œoverrideā€, existing settings are set to the provided values. If ā€œupdateā€, existing settings are not changed; only settings that were previously unset are changed.

barplot_kwargs : dict, default=None

Keyword arguments to be passed to seaborn.barplot() for rendering the mean decrease in impurity with an EstimatorReport.

stripplot_kwargs : dict, default=None

Keyword arguments to be passed to seaborn.stripplot() for rendering the mean decrease in impurity with a CrossValidationReport.

boxplot_kwargs : dict, default=None

Keyword arguments to be passed to seaborn.boxplot() for rendering the mean decrease in impurity with a CrossValidationReport.
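A cross-validation report has one MDI vector per split, so each feature gets a distribution of values rather than a single bar. This is presumably why strip and box plots are used instead of a bar plot. The sketch below collects those per-split values with scikit-learn's `cross_validate` rather than skore. The `split` column name is hypothetical.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

iris = load_iris(as_frame=True)
X, y = iris.data, iris.target
cv_results = cross_validate(
    RandomForestClassifier(random_state=0), X, y, cv=5, return_estimator=True
)

# Long format: one row per (split, feature); "split" is a hypothetical column name
rows = [
    {"split": i, "feature": feature, "importances": value}
    for i, fitted in enumerate(cv_results["estimator"])
    for feature, value in zip(X.columns, fitted.feature_importances_)
]
per_split = pd.DataFrame(rows)

# The spread across splits is what a per-feature strip plot or box plot shows
print(per_split.groupby("feature")["importances"].agg(["mean", "std"]))
```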

Returns:
self : object

The instance with a modified style.

Raises:
ValueError

If a style parameter is unknown.
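The two policies can be illustrated on plain dictionaries. `merge_style` is a hypothetical helper, not part of skore. It assumes that a setting counts as "unset" when its key is absent. Rejecting an unknown policy with ValueError is also part of the sketch only.

```python
def merge_style(current, new, policy="update"):
    """Hypothetical helper mirroring the documented "override"/"update" policies."""
    if policy not in ("override", "update"):
        raise ValueError(f"Unknown policy: {policy!r}")
    merged = dict(current)
    for key, value in new.items():
        # "override" always writes; "update" only fills keys that are not set yet
        if policy == "override" or key not in merged:
            merged[key] = value
    return merged


current = {"color": "tab:blue"}
new = {"color": "tab:orange", "alpha": 0.7}
print(merge_style(current, new, policy="update"))    # {'color': 'tab:blue', 'alpha': 0.7}
print(merge_style(current, new, policy="override"))  # {'color': 'tab:orange', 'alpha': 0.7}
```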

static style_plot(plot_func)

Apply consistent style to skore displays.

This decorator:

1. Applies the default style settings.
2. Executes plot_func.
3. Calls plt.tight_layout() so that axes elements do not overlap.
4. Restores the original style settings.
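The four steps can be sketched as a plain decorator built on `matplotlib.pyplot.rc_context`, which saves the current rcParams on entry and restores them on exit. The `DEFAULT_STYLE` settings below are hypothetical placeholders, not skore's actual style.

```python
import functools

import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# Hypothetical default style; skore's actual settings may differ
DEFAULT_STYLE = {"axes.spines.top": False, "axes.spines.right": False}


def style_plot(plot_func):
    @functools.wraps(plot_func)
    def wrapper(*args, **kwargs):
        # 1. apply the defaults; rc_context restores the previous rcParams on exit (4.)
        with plt.rc_context(DEFAULT_STYLE):
            result = plot_func(*args, **kwargs)  # 2. run the plotting code
            plt.tight_layout()  # 3. keep axis labels and ticks from overlapping
        return result

    return wrapper


@style_plot
def bar_chart():
    fig, ax = plt.subplots()
    ax.bar(["petal length", "petal width"], [0.4, 0.3])
    return plt.rcParams["axes.spines.top"]


print(bar_chart())                      # False while the decorator is active
print(plt.rcParams["axes.spines.top"])  # previous value restored afterwards
```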

Parameters:
plot_func : callable

The plot function to be decorated.

Returns:
callable

The decorated plot function.