{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Cache mechanism\n\nThis example shows how :class:`~skore.EstimatorReport` and\n:class:`~skore.CrossValidationReport` use caching to speed up computations.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Loading some data\n\nFirst, we load a dataset from `skrub`. Our goal is to predict if a company paid a\nphysician. The ultimate goal is to detect potential conflict of interest when it comes\nto the actual problem that we want to solve.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from skrub.datasets import fetch_open_payments\n\ndataset = fetch_open_payments()\ndf = dataset.X\ny = dataset.y"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from skrub import TableReport\n\nTableReport(df)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n\nTableReport(pd.DataFrame(y))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The dataset has over 70,000 records with only categorical features.\nSome categories are not well defined.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Caching with :class:`~skore.EstimatorReport` and :class:`~skore.CrossValidationReport`\n\nWe use `skrub` to create a simple predictive model that handles our dataset's\nchallenges.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from skrub import tabular_pipeline\n\nmodel = tabular_pipeline(\"classifier\")\nmodel"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This model handles all types of data: numbers, categories, dates, and missing values.\nLet's train it on part of our dataset.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from skore import train_test_split\n\nX_train, X_test, y_train, y_test = train_test_split(df, y, random_state=42)\n# Let's keep a completely separate dataset\nX_train, X_external, y_train, y_external = train_test_split(\n    X_train, y_train, random_state=42\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Caching the predictions for fast metric computation\n\nFirst, we focus on :class:`~skore.EstimatorReport`, as the same philosophy will\napply to :class:`~skore.CrossValidationReport`.\n\nLet's explore how :class:`~skore.EstimatorReport` uses caching to speed up\npredictions. We start by training the model:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from skore import EstimatorReport\n\nreport = EstimatorReport(\n    model,\n    X_train=X_train,\n    y_train=y_train,\n    X_test=X_test,\n    y_test=y_test,\n    pos_label=\"allowed\",\n)\nreport.help()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We compute the accuracy on our test set and measure how long it takes:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import time\n\nstart = time.time()\nresult = report.metrics.accuracy()\nend = time.time()\nresult"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"Time taken: {end - start:.2f} seconds\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For comparison, here's how scikit-learn computes the same accuracy score:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from sklearn.metrics import accuracy_score\n\nstart = time.time()\nresult = accuracy_score(report.y_test, report.estimator_.predict(report.X_test))\nend = time.time()\nresult"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"Time taken: {end - start:.2f} seconds\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Both approaches take similar time.\n\nNow, watch what happens when we compute the accuracy again with our skore estimator\nreport:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "start = time.time()\nresult = report.metrics.accuracy()\nend = time.time()\nresult"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"Time taken: {end - start:.2f} seconds\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The second calculation is instant! This happens because the report saves previous\ncalculations in its cache. Let's look inside the cache:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "report._cache"
      ]
    },
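    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build intuition, here is a minimal, hypothetical sketch of the caching idea:\npredictions are stored in a dictionary keyed by the data source and the type of\nprediction, so any metric needing the same kind of predictions reuses them. This is\nan illustration only, not skore's actual implementation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Hypothetical sketch of prediction caching, not skore's actual implementation\nclass PredictionCache:\n    def __init__(self, predict_funcs):\n        # map a prediction type (e.g. \"predict\") to a callable\n        self._predict_funcs = predict_funcs\n        self._cache = {}\n\n    def get(self, data_source, response_method, X):\n        key = (data_source, response_method)\n        if key not in self._cache:  # compute once, reuse afterwards\n            self._cache[key] = self._predict_funcs[response_method](X)\n        return self._cache[key]\n\n\ncalls = {\"n\": 0}\n\ndef slow_predict(X):\n    calls[\"n\"] += 1\n    return [0 for _ in X]\n\ncache = PredictionCache({\"predict\": slow_predict})\ncache.get(\"test\", \"predict\", [1, 2, 3])\ncache.get(\"test\", \"predict\", [1, 2, 3])  # served from the cache\ncalls[\"n\"]  # the expensive predict ran only once"
      ]
    },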
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cache stores predictions by type and data source. This means that computing\nmetrics that use the same type of predictions will be faster.\nLet's try the precision metric:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "start = time.time()\nresult = report.metrics.precision()\nend = time.time()\nresult"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"Time taken: {end - start:.2f} seconds\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We observe that it takes only a few milliseconds to compute the precision because we\ndon't need to re-compute the predictions and only have to compute the precision\nmetric itself.\nSince the predictions are the bottleneck in terms of computation time, we observe\nan interesting speedup.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Caching all the possible predictions at once\n\nWe can pre-compute all predictions at once:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "report.cache_predictions()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now, all possible predictions are stored. Any metric calculation will be much faster,\neven on different data (like the training set):\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "start = time.time()\nresult = report.metrics.log_loss(data_source=\"train\")\nend = time.time()\nresult"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"Time taken: {end - start:.2f} seconds\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Caching for plotting\n\nThe cache also speeds up plots. Let's create a ROC curve:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "start = time.time()\ndisplay = report.metrics.roc()\ndisplay.plot()\nend = time.time()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"Time taken: {end - start:.2f} seconds\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The second plot is instant because it uses cached data:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "start = time.time()\ndisplay = report.metrics.roc()\ndisplay.plot()\nend = time.time()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"Time taken: {end - start:.2f} seconds\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We only use the cache to retrieve the `display` object and not directly the matplotlib\nfigure. It means that we can still customize the cached plot before displaying it:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "display.set_style(relplot_kwargs={\"color\": \"tab:orange\"})\ndisplay.plot()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Be aware that we can clear the cache if we want to:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "report.clear_cache()\nreport._cache"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "It means that nothing is stored anymore in the cache.\n\n### Caching with :class:`~skore.CrossValidationReport`\n\n:class:`~skore.CrossValidationReport` uses the same caching system for each split\nin cross-validation by leveraging the previous :class:`~skore.EstimatorReport`:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from skore import CrossValidationReport\n\nreport = CrossValidationReport(model, X=df, y=y, splitter=5, n_jobs=4)\nreport.help()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since a :class:`~skore.CrossValidationReport` uses many\n:class:`~skore.EstimatorReport`, we will observe the same behaviour as we previously\nexposed.\nThe first call will be slow because it computes the predictions for each split.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "start = time.time()\nresult = report.metrics.summarize().frame()\nend = time.time()\nresult"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"Time taken: {end - start:.2f} seconds\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "But the subsequent calls are fast because the predictions are cached.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "start = time.time()\nresult = report.metrics.summarize().frame()\nend = time.time()\nresult"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"Time taken: {end - start:.2f} seconds\")"
      ]
    },
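    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, the idea extends to cross-validation with one prediction cache per\nsplit, so each split's predictions are computed once and reused by every subsequent\nmetric. Again, this is a hypothetical sketch rather than skore's actual\nimplementation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Hypothetical sketch: one prediction cache per cross-validation split\nclass SplitCache:\n    def __init__(self, predict_func):\n        self._predict_func = predict_func\n        self._predictions = None\n\n    def predictions(self, X):\n        if self._predictions is None:  # computed once per split\n            self._predictions = self._predict_func(X)\n        return self._predictions\n\n\ncalls = {\"n\": 0}\n\ndef slow_predict(X):\n    calls[\"n\"] += 1\n    return [0 for _ in X]\n\nsplits = [[1, 2], [3, 4], [5, 6]]\ncaches = [SplitCache(slow_predict) for _ in splits]\n\n# first \"metric\": predictions are computed for each split\nfor cache, X in zip(caches, splits):\n    cache.predictions(X)\n# second \"metric\": everything is served from the per-split caches\nfor cache, X in zip(caches, splits):\n    cache.predictions(X)\ncalls[\"n\"]  # one expensive call per split, not per metric"
      ]
    },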
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Hence, we observe the same type of behaviour as we previously exposed.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.14.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}