{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# `train_test_split`: get diagnostics when splitting your data\n\nThis example illustrates the motivation and the use of skore's\n:func:`skore.train_test_split` to get assistance when developing ML/DS projects.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Train-test split in scikit-learn\n\nScikit-learn has a function for splitting the data into train and test\nsets: :func:`sklearn.model_selection.train_test_split`.\nIts signature is the following:\n\n```python\nsklearn.model_selection.train_test_split(\n    *arrays,\n    test_size=None,\n    train_size=None,\n    random_state=None,\n    shuffle=True,\n    stratify=None\n)\n```\nwhere ``*arrays`` is a Python ``*args`` (it allows you to pass a varying number of\npositional arguments), which the scikit-learn documentation describes as ``a sequence of\nindexables with same length / shape[0]``.\n\n"
      ]
    },
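    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because ``*arrays`` is variadic, any number of equally long arrays can be split in one call, and a (train, test) pair comes back for each input. A minimal sketch, where the sample-weight array ``w`` is an illustrative assumption and not part of this example's data:\n\n```python\nimport numpy as np\nfrom sklearn.model_selection import train_test_split\n\nX = np.arange(10).reshape((5, 2))\ny = np.arange(5)\nw = np.ones(5)  # hypothetical per-sample weights\n\n# Three inputs give six outputs: a (train, test) pair per input.\nX_train, X_test, y_train, y_test, w_train, w_test = train_test_split(\n    X, y, w, test_size=0.2, random_state=0\n)\nprint(X_train.shape, y_train.shape, w_train.shape)\n```\n"
      ]
    },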
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let us construct a design matrix ``X`` and target ``y`` to illustrate our point:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\nX = np.arange(10).reshape((5, 2))\ny = np.arange(5)\nprint(f\"{X = }\\n{y = }\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In scikit-learn, the most common usage is the following:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from sklearn.model_selection import train_test_split as sklearn_train_test_split\n\nX_train, X_test, y_train, y_test = sklearn_train_test_split(\n    X, y, test_size=0.2, random_state=0\n)\nprint(f\"{X_train = }\\n{y_train = }\\n{X_test = }\\n{y_test = }\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Notice the shuffling that is done by default.\n\n"
      ]
    },
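    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For comparison, here is a minimal sketch of the same split with shuffling disabled. With ``shuffle=False``, scikit-learn keeps the original row order, so the first rows go to the train set and the last ones to the test set:\n\n```python\nimport numpy as np\nfrom sklearn.model_selection import train_test_split\n\nX = np.arange(10).reshape((5, 2))\ny = np.arange(5)\n\n# No shuffling: the split is a simple cut of the ordered data.\nX_train, X_test, y_train, y_test = train_test_split(\n    X, y, test_size=0.2, shuffle=False\n)\nprint(y_train, y_test)\n```\n"
      ]
    },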
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In scikit-learn, the user cannot pass the design matrix ``X`` and\nthe target ``y`` as keyword arguments. The following:\n\n```python\nX_train, X_test, y_train, y_test = sklearn_train_test_split(\n    X=X, y=y, test_size=0.2, random_state=0)\n```\nwould raise:\n\n```python\nTypeError: got an unexpected keyword argument 'X'\n```\nIn general, in Python, keyword arguments help prevent mistakes such as passing\narguments in the wrong order. For example,\nin the following, ``X`` and ``y`` are reversed:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X_train, X_test, y_train, y_test = sklearn_train_test_split(\n    y, X, test_size=0.2, random_state=0\n)\nprint(f\"{X_train = }\\n{y_train = }\\n{X_test = }\\n{y_test = }\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "but Python will not catch this mistake for us.\nThis is where skore comes in handy.\n\n"
      ]
    },
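    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of why keyword arguments help: with a keyword-only wrapper, the order in which ``X`` and ``y`` are written no longer matters. The helper ``split_xy`` below is hypothetical and only illustrates the idea; it is not part of scikit-learn or skore:\n\n```python\nimport numpy as np\nfrom sklearn.model_selection import train_test_split\n\n\n# Hypothetical keyword-only wrapper: the caller must name X and y.\ndef split_xy(*, X, y, **kwargs):\n    return train_test_split(X, y, **kwargs)\n\n\nX = np.arange(10).reshape((5, 2))\ny = np.arange(5)\n\n# Writing y before X is harmless, because each array is bound by name.\nX_train, X_test, y_train, y_test = split_xy(y=y, X=X, test_size=0.2, random_state=0)\nprint(X_train.shape, y_train.shape)\n```\n"
      ]
    },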
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Train-test split in skore\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Skore has its own :func:`skore.train_test_split` that wraps scikit-learn's\n:func:`sklearn.model_selection.train_test_split`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X = np.arange(10_000).reshape((5_000, 2))\ny = [0] * 2_500 + [1] * 2_500"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Making the ``X`` and ``y`` arguments explicit\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First of all, naturally, it can be used as a simple drop-in replacement for\nscikit-learn:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import skore\n\nX_train, X_test, y_train, y_test = skore.train_test_split(\n    X, y, test_size=0.2, random_state=0\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-info\"><h4>Note</h4><p>The outputs of :func:`skore.train_test_split` are intentionally exactly the same as those of\n  :func:`sklearn.model_selection.train_test_split`, so the user can use the\n  skore version as a drop-in replacement for the scikit-learn one.</p></div>\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Contrary to scikit-learn, skore allows users to pass ``X`` and ``y`` explicitly as\nkeyword arguments, making detection of potential issues easier:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X_train, X_test, y_train, y_test = skore.train_test_split(\n    X=X, y=y, test_size=0.2, random_state=0\n)\nX_train_explicit = X_train.copy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Moreover, when passing ``X`` and ``y`` explicitly, the ``X``'s are always returned\nbefore the ``y``'s, even when they are passed in the reverse order:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "arr = X.copy()\narr_train, arr_test, X_train, X_test, y_train, y_test = skore.train_test_split(\n    arr, y=y, X=X, test_size=0.2, random_state=0\n)\nX_train_explicit_inverted = X_train.copy()\n\nprint(\"With keyword arguments in swapped order, are the `X_train`'s still the same?\")\nprint(np.allclose(X_train_explicit, X_train_explicit_inverted))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "#### Returning a dictionary instead of a tuple\n\nThe default behaviour of outputting a tuple of arrays can be cumbersome and\nerror-prone, in particular when passing them to an :class:`~skore.EstimatorReport`.\nThe new ``as_dict`` parameter makes the output a dictionary, which makes this simpler:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from sklearn.linear_model import LogisticRegression\nfrom skore import EstimatorReport\n\nsplit_data = skore.train_test_split(X=X, y=y, random_state=42, as_dict=True)\nsplit_data.keys()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "estimator = LogisticRegression(random_state=42)\nestimator_report = EstimatorReport(estimator, **split_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Without the dictionary output, this would be written:\n\n```python\nestimator_report = EstimatorReport(\n    estimator,\n    X_train=X_train,\n    y_train=y_train,\n    X_test=X_test,\n    y_test=y_test,\n)\n```\n"
      ]
    },
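    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition, a dictionary of this shape can be built by hand from the tuple output. The sketch below uses scikit-learn only and assumes the keys are named ``X_train``, ``X_test``, ``y_train`` and ``y_test``, as the keyword arguments above suggest; it is not skore's actual implementation:\n\n```python\nimport numpy as np\nfrom sklearn.model_selection import train_test_split\n\nX = np.arange(20).reshape((10, 2))\ny = [0] * 5 + [1] * 5\n\n# Assumed key names, in the order scikit-learn returns the arrays.\nnames = [\"X_train\", \"X_test\", \"y_train\", \"y_test\"]\nsplit_data = dict(zip(names, train_test_split(X, y, random_state=42)))\nprint(list(split_data))\n```\n"
      ]
    },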
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Automatic diagnostics: raising methodological warnings\n\nIn this section, we show how skore can provide methodological checks.\n\n#### Class imbalance\n\nIn machine learning, class imbalance (the classes in a dataset are not equally\nrepresented) requires specific modelling choices.\nFor example, in a dataset with 95% majority class (class ``1``) and 5% minority class\n(class ``0``), a dummy model that always predicts class ``1`` will have a 95%\naccuracy, while it would be useless for identifying examples of class ``0``.\nHence, it is important to detect when we have class imbalance.\n\nSuppose that we have imbalanced data:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X = np.arange(10_000).reshape((5_000, 2))\ny = [0] * 4_000 + [1] * 1_000"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In that case, :func:`skore.train_test_split` raises a ``HighClassImbalanceWarning``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X_train, X_test, y_train, y_test = skore.train_test_split(\n    X=X, y=y, test_size=0.2, random_state=0\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Hence, skore encourages users to account for this class imbalance, which they\nmight otherwise have missed, in their modelling strategy.\n\n"
      ]
    },
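    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the 95% accuracy example from the start of this section concrete, here is a minimal sketch using scikit-learn's ``DummyClassifier``. The toy data is made up for illustration:\n\n```python\nimport numpy as np\nfrom sklearn.dummy import DummyClassifier\nfrom sklearn.metrics import accuracy_score, recall_score\n\n# Toy data: 95% of samples in class 1, 5% in class 0.\ny = np.array([1] * 950 + [0] * 50)\nX = np.zeros((1000, 1))  # features are irrelevant for a dummy model\n\n# Always predict the most frequent class (here, class 1).\ndummy = DummyClassifier(strategy=\"most_frequent\").fit(X, y)\ny_pred = dummy.predict(X)\n\naccuracy = accuracy_score(y, y_pred)\nrecall_minority = recall_score(y, y_pred, pos_label=0)\nprint(f\"{accuracy = }, {recall_minority = }\")\n```\n\nThe accuracy looks high, yet the recall on class ``0`` shows that no minority example is ever identified.\n"
      ]
    },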
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Moreover, skore also detects when a class has too few samples, and raises a\n``HighClassImbalanceTooFewExamplesWarning``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X = np.arange(400).reshape((200, 2))\ny = [0] * 150 + [1] * 50\n\nX_train, X_test, y_train, y_test = skore.train_test_split(\n    X=X, y=y, test_size=0.2, random_state=0\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "#### Shuffling without a random state\n\nFor [reproducible results across executions](https://scikit-learn.org/stable/common_pitfalls.html#controlling-randomness),\nskore recommends setting the ``random_state`` parameter when shuffling\n(remember that ``shuffle=True`` by default), and raises a ``RandomStateUnsetWarning``\nwhen it is left unset:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X = np.arange(10_000).reshape((5_000, 2))\ny = [0] * 2_500 + [1] * 2_500\n\nX_train, X_test, y_train, y_test = skore.train_test_split(X=X, y=y, test_size=0.2)"
      ]
    },
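    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of what setting ``random_state`` buys, using scikit-learn directly on made-up toy data: two calls with the same seed return identical splits.\n\n```python\nimport numpy as np\nfrom sklearn.model_selection import train_test_split\n\nX = np.arange(20).reshape((10, 2))\ny = [0] * 5 + [1] * 5\n\nfirst = train_test_split(X, y, test_size=0.2, random_state=0)\nsecond = train_test_split(X, y, test_size=0.2, random_state=0)\n\n# With the same seed, every output array matches between the two calls.\nsame = all(np.array_equal(a, b) for a, b in zip(first, second))\nprint(same)\n```\n"
      ]
    },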
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "#### Time series data\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now, let us assume that we have [time series data](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-of-time-series-data),\nthat is, data that is somewhat time-ordered:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pandas as pd\nfrom skrub.datasets import fetch_employee_salaries\n\ndataset = fetch_employee_salaries()\nX, y = dataset.X, dataset.y\nX[\"date_first_hired\"] = pd.to_datetime(X[\"date_first_hired\"], format=\"%m/%d/%Y\")\nX.head(2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can observe that there is a ``date_first_hired`` column, which is time-based.\n\nAs one cannot shuffle time (time only moves in one direction: forward), skore raises a\n``TimeBasedColumnWarning`` and recommends using\n:class:`sklearn.model_selection.TimeSeriesSplit` instead of\n:func:`sklearn.model_selection.train_test_split` (or :func:`skore.train_test_split`):\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X_train, X_test, y_train, y_test = skore.train_test_split(\n    X, y, random_state=0, shuffle=False\n)"
      ]
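    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of the recommended alternative, :class:`sklearn.model_selection.TimeSeriesSplit` yields folds in which every training index precedes every test index. The toy array below is an assumption standing in for rows already sorted chronologically:\n\n```python\nimport numpy as np\nfrom sklearn.model_selection import TimeSeriesSplit\n\nX = np.arange(12).reshape((6, 2))  # toy rows, assumed sorted by time\n\ntscv = TimeSeriesSplit(n_splits=3)\nsplits = list(tscv.split(X))\n\nfor train_idx, test_idx in splits:\n    # The model is always trained on the past and tested on the future.\n    print(train_idx, test_idx)\n```\n"
      ]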
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.14.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}