{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Adding custom diagnostic checks\n\n`skore` lets you extend the built-in diagnostic checks with your own.\nThis example shows how to write a custom check function and register it\nwith a report via :meth:`~skore.EstimatorReport.add_checks`.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Writing a custom check for a single estimator\n\nWe start by defining a simple check that flags models with a very large\nnumber of features. The check inspects the test data attached to the\nreport. We throw an exception when the test data is not available to avoid\nrunning the check when it is not applicable. The check function is wrapped in a\n:class:`~skore.Check` instance and registered with the report via\n:meth:`~skore.EstimatorReport.add_checks`.\n\nThe `docs_url` argument is optional. When provided as a full URL (starting\nwith ``\"http\"``), it is used as-is. When it is a plain anchor string\nit points to the skore diagnostic user guide. When omitted entirely,\nno documentation link is shown.\n\nWe set the severity to \"tip\" to indicate that this is not an issue to fix,\nbut a cautionary note about the dataset. Severity can also be set to \"issue\" to\nindicate when there is an issue to fix.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom skore import Check, CheckNotApplicable\n\n\nclass CustomCheck1(Check):\n    code = \"CSTM001\"\n    title = \"High feature count\"\n    report_type = \"estimator\"\n    severity = \"tip\"\n    docs_url = \"https://scikit-learn.org/stable/modules/feature_selection.html#feature-selection\"\n\n    def check_function(self, report):\n        \"\"\"Flag when the number of features exceeds a threshold.\"\"\"\n        if report.X_test is None:\n            raise CheckNotApplicable()\n\n        n_features = X.shape[1]\n        if n_features > 50:\n            return (\n                f\"The dataset has {n_features} features which may hurt model performance. \"\n                \"Consider feature selection or dimensionality reduction.\"\n            )\n        return None"
      ]
    },
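    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The resolution rule for `docs_url` described above can be sketched as a small\nstandalone helper. This is only an illustration: the `resolve_docs_url` function\nand the base URL below are hypothetical, not part of the skore API.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "GUIDE_URL = \"https://example.org/skore/diagnostics.html\"  # placeholder base URL\n\n\ndef resolve_docs_url(docs_url):\n    \"\"\"Sketch of the docs_url resolution rule described above.\"\"\"\n    if docs_url is None:\n        # No documentation link is shown.\n        return None\n    if docs_url.startswith(\"http\"):\n        # Full URLs are used as-is.\n        return docs_url\n    # Plain anchors point into the diagnostic user guide.\n    return f\"{GUIDE_URL}#{docs_url}\"\n\n\nresolve_docs_url(\"feature-selection\")"
      ]
    },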
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Registering the check\n\n:meth:`~skore.EstimatorReport.add_checks` accepts a list of ``Check`` instances,\nand registers them. The next call to :meth:`~skore.EstimatorReport.diagnose` runs\nany newly added checks on top of the built-in checks.\n\nWe can then find the new check in the Tips tab of the diagnostic, along another tip\ninforming us that the dataset is not standardized.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from sklearn.linear_model import LinearRegression\nfrom skore import evaluate\n\nrng = np.random.default_rng(42)\nX = rng.normal(size=(200, 80))\ny = X[:, 0] + rng.normal(size=200)\n\nreport = evaluate(LinearRegression(), X, y)\nreport.add_checks([CustomCheck1()])\nreport.diagnose()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Cross-validation level checks\n\n:class:`~skore.CrossValidationReport` and :class:`~skore.ComparisonReport` can also\nreceive custom checks, either ran on the full report or on the component estimator\nreports.\n\nThe `report_type` argument of :class:`~skore.Check` controls the scope of the check.\nLet's write a check that is specific to cross-validation reports: it flags metrics\nwith high variance across splits. We set the severity to \"issue\" to indicate that\nthis is an issue to fix.\n\nWe will corrupt the first fold of the target to illustrate the check.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n\ny_noisy = y.copy()\ny_noisy[: len(y_noisy) // 5] = rng.normal(size=len(y_noisy) // 5)\ncv_report = evaluate(LinearRegression(), X, y_noisy, splitter=5)\n\n\nclass CustomCheck2(Check):\n    code = \"CSTM002\"\n    title = \"High score variance across CV splits\"\n    report_type = \"cross-validation\"\n    docs_url = None\n    severity = \"issue\"\n\n    def check_function(self, report):\n        \"\"\"Flag high score variance across CV splits.\"\"\"\n        frames = [\n            sub_report.metrics.summarize(data_source=\"test\").data\n            for sub_report in report.estimator_reports_\n        ]\n        scores = pd.concat(frames, ignore_index=True)\n\n        high_var_metrics = [\n            metric_name\n            for metric_name, group in scores.groupby(\"metric_verbose_name\")\n            if group[\"score\"].std() > 0.1\n        ]\n\n        if high_var_metrics:\n            return f\"Metrics with high variance: {', '.join(high_var_metrics)}.\"\n        return None\n\n\ncv_report.add_checks([CustomCheck2()])\ncv_report.diagnose()"
      ]
    },
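    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The variance filter inside ``CustomCheck2.check_function`` can be exercised on\nits own with a toy scores table, independently of any report. The column names\nbelow mirror the ones assumed in the check above.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n\ntoy_scores = pd.DataFrame(\n    {\n        \"metric_verbose_name\": [\"R2\", \"R2\", \"R2\", \"RMSE\", \"RMSE\", \"RMSE\"],\n        \"score\": [0.90, 0.60, 0.95, 0.50, 0.51, 0.52],\n    }\n)\n\n# R2 varies a lot across the three splits, RMSE barely at all,\n# so only R2 is flagged by the 0.1 threshold.\nhigh_var_metrics = [\n    metric_name\n    for metric_name, group in toy_scores.groupby(\"metric_verbose_name\")\n    if group[\"score\"].std() > 0.1\n]\nhigh_var_metrics"
      ]
    },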
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Aggregating checks across estimator reports\n\nWe can also reuse our first check to run it on the component estimator reports\nand aggregate the results across splits.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "cv_report.add_checks([CustomCheck1()])\ncv_report.diagnose()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Similarly, :class:`~skore.ComparisonReport` aggregates checks across its\ncomponent reports.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from sklearn.ensemble import RandomForestRegressor\n\ncomparison_report = evaluate(\n    [LinearRegression(), RandomForestRegressor()], X, y, splitter=5\n)\ncomparison_report.add_checks([CustomCheck1(), CustomCheck2()])\ncomparison_report.diagnose()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.14.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}