{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-label data stratification\n",
    "\n",
    "With the development of more complex multi-label transformation methods, the community has come to realize how much the quality of classification depends on how the data is split into train/test sets or into folds for parameter estimation. More and more questions appear on Stack Overflow or [Data Science Stack Exchange](https://datascience.stackexchange.com/questions/33076/how-can-i-perform-stratified-sampling-for-multi-label-multi-class-classification) concerning methods for multi-label stratification.\n",
    "\n",
    "For many reasons, described [here](http://lpis.csd.auth.gr/publications/sechidis-ecmlpkdd-2011.pdf) and [here](http://proceedings.mlr.press/v74/szyma%C5%84ski17a.html), traditional single-label approaches to stratifying data fail to provide balanced data set divisions, which prevents classifiers from generalizing information.\n",
    "\n",
    "Some train/test splits don't include evidence for a given label in the train set at all. Others disproportionately put as much as 70% of the evidence for a label pair in the test set, leaving the train set without proper evidence for generalizing conditional probabilities for label relations.\n",
    "\n",
    "You can also watch a great video presentation from ECML PKDD 2011 that explains this in depth:\n",
    "\n",
    "<blockquote>\n",
    "<a href='http://videolectures.net/ecmlpkdd2011_tsoumakas_stratification/'>\n",
    "  <img src='http://videolectures.net/ecmlpkdd2011_tsoumakas_stratification/thumb.jpg' border=0 />\n",
    "  <br/>On the Stratification of Multi-Label Data</a><br/>\n",
    "Grigorios Tsoumakas\n",
    "</blockquote>\n"
   ]
  },
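  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the first failure mode concrete, here is a minimal plain-Python sketch that does not need scikit-multilearn. The toy label matrix `toy_y` and the helper `label_pairs` are hypothetical, made up for illustration: cutting the data naively in half leaves label 2, and therefore both pairs involving it, only in the test half.\n",
    "\n",
    "```python\n",
    "from itertools import combinations\n",
    "\n",
    "# hypothetical toy label matrix: rows are examples, columns are labels\n",
    "toy_y = [\n",
    "    [1, 1, 0],\n",
    "    [1, 0, 0],\n",
    "    [0, 1, 0],\n",
    "    [1, 0, 1],\n",
    "    [0, 1, 1],\n",
    "    [1, 0, 0],\n",
    "]\n",
    "\n",
    "\n",
    "def label_pairs(rows):\n",
    "    # set of label pairs (i, j) with i < j that co-occur in at least one row\n",
    "    pairs = set()\n",
    "    for row in rows:\n",
    "        active = [i for i, value in enumerate(row) if value]\n",
    "        pairs.update(combinations(active, 2))\n",
    "    return pairs\n",
    "\n",
    "\n",
    "# naive split: first half for training, second half for testing\n",
    "train_rows, test_rows = toy_y[:3], toy_y[3:]\n",
    "missing_from_train = label_pairs(test_rows) - label_pairs(train_rows)\n",
    "print(sorted(missing_from_train))  # [(0, 2), (1, 2)]\n",
    "```\n",
    "\n",
    "A classifier trained on `train_rows` never sees label 2 at all, so it has no evidence for how label 2 relates to labels 0 and 1."
   ]
  },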
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scikit-multilearn provides an implementation of iterative stratification, which aims to provide a well-balanced distribution of evidence of label relations up to a given order. To see what this means, let's load up some data. We'll be using the scene data set, in both its divided and undivided variants, to illustrate the problem."
   ]
  },
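  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The core idea of iterative stratification from the Sechidis et al. paper linked above can be sketched in plain Python. This is a simplified, first-order sketch under stated assumptions, not scikit-multilearn's implementation: it balances single labels only (the library also balances label pairs and higher orders), and it breaks ties by fold index rather than randomly. The function `iterative_stratification` and its arguments are hypothetical names.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def iterative_stratification(label_sets, ratios):\n",
    "    # label_sets: one set of label indices per example\n",
    "    # ratios: desired fraction of examples per fold, e.g. [0.5, 0.5]\n",
    "    n = len(label_sets)\n",
    "    remaining = set(range(n))\n",
    "    desired = [r * n for r in ratios]  # examples each fold still wants\n",
    "    label_totals = defaultdict(int)\n",
    "    for labels in label_sets:\n",
    "        for label in labels:\n",
    "            label_totals[label] += 1\n",
    "    # examples of each label that each fold still wants\n",
    "    desired_per_label = {label: [r * total for r in ratios]\n",
    "                         for label, total in label_totals.items()}\n",
    "    folds = [[] for _ in ratios]\n",
    "\n",
    "    def assign(example, fold):\n",
    "        folds[fold].append(example)\n",
    "        remaining.discard(example)\n",
    "        desired[fold] -= 1\n",
    "        for label in label_sets[example]:\n",
    "            desired_per_label[label][fold] -= 1\n",
    "\n",
    "    while True:\n",
    "        # how many unassigned examples still carry each label\n",
    "        counts = defaultdict(int)\n",
    "        for example in remaining:\n",
    "            for label in label_sets[example]:\n",
    "                counts[label] += 1\n",
    "        if not counts:\n",
    "            break\n",
    "        # handle the rarest remaining label first\n",
    "        rarest = min(counts, key=lambda label: (counts[label], label))\n",
    "        for example in sorted(e for e in remaining if rarest in label_sets[e]):\n",
    "            # fold wanting most of this label; ties go to overall demand, then lowest index\n",
    "            fold = max(range(len(folds)),\n",
    "                       key=lambda j: (desired_per_label[rarest][j], desired[j], -j))\n",
    "            assign(example, fold)\n",
    "\n",
    "    # examples without labels go to the folds that still want the most examples\n",
    "    for example in sorted(remaining):\n",
    "        fold = max(range(len(folds)), key=lambda j: (desired[j], -j))\n",
    "        assign(example, fold)\n",
    "    return folds\n",
    "\n",
    "\n",
    "print(iterative_stratification([{0}, {0}, {0, 1}, {0, 1}, {1}, {1}, set(), set()], [0.5, 0.5]))\n",
    "# [[0, 2, 4, 6], [1, 3, 5, 7]]\n",
    "```\n",
    "\n",
    "In this example each fold receives two examples of label 0, two of label 1 and one unlabeled example. Handling the rarest label first is the key design choice: if rare labels were left for last, the folds' demands might already be used up, leaving a fold with no examples of them."
   ]
  },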
  {
   "cell_type": "code",
   "execution_count": 263,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scene:undivided - exists, not redownloading\n"
     ]
    }
   ],
   "source": [
    "from skmultilearn.dataset import load_dataset\n",
    "X, y, _, _ = load_dataset('scene', 'undivided')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at how many examples are available per label combination:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 264,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from skmultilearn.model_selection.measures import get_combination_wise_output_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 265,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({(0, 0): 427,\n",
       "         (0, 3): 1,\n",
       "         (0, 4): 38,\n",
       "         (0, 5): 19,\n",
       "         (1, 1): 364,\n",
       "         (2, 2): 397,\n",
       "         (2, 3): 24,\n",
       "         (2, 4): 14,\n",
       "         (3, 3): 433,\n",
       "         (3, 4): 76,\n",
       "         (3, 5): 6,\n",
       "         (4, 4): 533,\n",
       "         (4, 5): 1,\n",
       "         (5, 5): 431})"
      ]
     },
     "execution_count": 265,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Counter(combination for row in get_combination_wise_output_matrix(y.toarray(), order=2) for combination in row)"
   ]
  },
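  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In these counts, diagonal entries such as `(0, 0)` count how often a single label occurs, while off-diagonal entries count co-occurring pairs; for example, 38 images carry both label 0 and label 4. The sketch below reproduces this kind of counting in plain Python on a hypothetical toy matrix. It assumes, as the diagonal entries above suggest, that order-2 combinations are taken with replacement over each row's active labels; `combination_counts` is a hypothetical helper, not the library's `get_combination_wise_output_matrix`.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "from itertools import combinations_with_replacement\n",
    "\n",
    "\n",
    "def combination_counts(rows, order=2):\n",
    "    # count each combination (with replacement) of a row's active labels, over all rows\n",
    "    counts = Counter()\n",
    "    for row in rows:\n",
    "        active = [i for i, value in enumerate(row) if value]\n",
    "        counts.update(combinations_with_replacement(active, order))\n",
    "    return counts\n",
    "\n",
    "\n",
    "toy_counts = combination_counts([[1, 0, 1], [0, 1, 1], [1, 0, 0]])\n",
    "print(sorted(toy_counts.items()))\n",
    "# [((0, 0), 2), ((0, 2), 1), ((1, 1), 1), ((1, 2), 1), ((2, 2), 2)]\n",
    "```\n",
    "\n",
    "Every row with label 0 contributes exactly one `(0, 0)`, so under this assumption the diagonal entries equal the per-label frequencies."
   ]
  },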
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's load up the original division to see how the set was split into train/test data in 2004, before multi-label stratification methods appeared."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 266,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scene:train - exists, not redownloading\n",
      "scene:test - exists, not redownloading\n"
     ]
    }
   ],
   "source": [
    "_, original_y_train, _, _ = load_dataset('scene', 'train')\n",
    "_, original_y_test, _, _ = load_dataset('scene', 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 267,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 268,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>(0, 0)</th>\n",
       "      <th>(0, 3)</th>\n",
       "      <th>(0, 4)</th>\n",
       "      <th>(0, 5)</th>\n",
       "      <th>(1, 1)</th>\n",
       "      <th>(2, 2)</th>\n",
       "      <th>(2, 3)</th>\n",
       "      <th>(2, 4)</th>\n",
       "      <th>(3, 3)</th>\n",
       "      <th>(3, 4)</th>\n",
       "      <th>(3, 5)</th>\n",
       "      <th>(4, 4)</th>\n",
       "      <th>(4, 5)</th>\n",
       "      <th>(5, 5)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>test</th>\n",
       "      <td>200.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>199.0</td>\n",
       "      <td>200.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>237.0</td>\n",
       "      <td>49.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>207.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>train</th>\n",
       "      <td>227.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>165.0</td>\n",
       "      <td>197.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>196.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>277.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>224.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       (0, 0)  (0, 3)  (0, 4)  (0, 5)  (1, 1)  (2, 2)  (2, 3)  (2, 4)  (3, 3)  \\\n",
       "test    200.0     1.0    17.0     7.0   199.0   200.0    16.0     8.0   237.0   \n",
       "train   227.0     0.0    21.0    12.0   165.0   197.0     8.0     6.0   196.0   \n",
       "\n",
       "       (3, 4)  (3, 5)  (4, 4)  (4, 5)  (5, 5)  \n",
       "test     49.0     5.0   256.0     0.0   207.0  \n",
       "train    27.0     1.0   277.0     1.0   224.0  "
      ]
     },
     "execution_count": 268,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame({\n",
    "    'train': Counter(str(combination) for row in get_combination_wise_output_matrix(original_y_train.toarray(), order=2) for combination in row),\n",
    "    'test' : Counter(str(combination) for row in get_combination_wise_output_matrix(original_y_test.toarray(), order=2) for combination in row)\n",
    "}).T.fillna(0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1211, 1196)"
      ]
     },
     "execution_count": 269,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "original_y_train.shape[0], original_y_test.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
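    "We can see that the split sizes are nearly identical, yet the evidence for some label combinations is poorly balanced between the splits: the pair (3, 4) appears 49 times in the test set but only 27 times in the train set, and the pairs (0, 3) and (4, 5) appear in only one of the two sets. While this is a toy case on a small data set, such phenomena are common in larger data sets. We would like to fix this.\n",
    "\n",
    "Let's load the iterative stratifier and divide the set again."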
   ]
  },
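  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to quantify the imbalance of the original split is to compute, for each label combination, the share of its evidence that landed in the test set, and compare it with the overall test ratio (1196 of 2407 examples). The sketch below uses counts copied from the table above; the helper `share_in_test` and the 10-percentage-point tolerance are hypothetical choices for illustration.\n",
    "\n",
    "```python\n",
    "# pair counts copied from the original scene split shown above\n",
    "original_train = {(0, 0): 227, (0, 3): 0, (0, 4): 21, (0, 5): 12, (1, 1): 165,\n",
    "                  (2, 2): 197, (2, 3): 8, (2, 4): 6, (3, 3): 196, (3, 4): 27,\n",
    "                  (3, 5): 1, (4, 4): 277, (4, 5): 1, (5, 5): 224}\n",
    "original_test = {(0, 0): 200, (0, 3): 1, (0, 4): 17, (0, 5): 7, (1, 1): 199,\n",
    "                 (2, 2): 200, (2, 3): 16, (2, 4): 8, (3, 3): 237, (3, 4): 49,\n",
    "                 (3, 5): 5, (4, 4): 256, (4, 5): 0, (5, 5): 207}\n",
    "\n",
    "\n",
    "def share_in_test(train_counts, test_counts):\n",
    "    # fraction of each combination's evidence that ended up in the test set\n",
    "    shares = {}\n",
    "    for combination, train_count in train_counts.items():\n",
    "        total = train_count + test_counts[combination]\n",
    "        shares[combination] = test_counts[combination] / float(total)\n",
    "    return shares\n",
    "\n",
    "\n",
    "shares = share_in_test(original_train, original_test)\n",
    "overall = 1196 / float(1211 + 1196)  # test set size over total size\n",
    "# hypothetical tolerance: flag combinations more than 10 points away from the overall ratio\n",
    "skewed = {c: round(s, 2) for c, s in shares.items() if abs(s - overall) > 0.1}\n",
    "print(sorted(skewed.items()))\n",
    "```\n",
    "\n",
    "This flags, among others, `(2, 3)` and `(3, 4)`, where roughly two thirds of the evidence sits in the test set, and `(0, 3)` and `(4, 5)`, whose single example falls entirely on one side."
   ]
  },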
  {
   "cell_type": "code",
   "execution_count": 270,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skmultilearn.model_selection import iterative_train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 278,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, y_train, X_test, y_test = iterative_train_test_split(X, y, test_size=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>(0, 0)</th>\n",
       "      <th>(0, 3)</th>\n",
       "      <th>(0, 4)</th>\n",
       "      <th>(0, 5)</th>\n",
       "      <th>(1, 1)</th>\n",
       "      <th>(2, 2)</th>\n",
       "      <th>(2, 3)</th>\n",
       "      <th>(2, 4)</th>\n",
       "      <th>(3, 3)</th>\n",
       "      <th>(3, 4)</th>\n",
       "      <th>(3, 5)</th>\n",
       "      <th>(4, 4)</th>\n",
       "      <th>(4, 5)</th>\n",
       "      <th>(5, 5)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>test</th>\n",
       "      <td>213.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>182.0</td>\n",
       "      <td>199.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>217.0</td>\n",
       "      <td>38.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>267.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>215.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>train</th>\n",
       "      <td>214.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>182.0</td>\n",
       "      <td>198.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>216.0</td>\n",
       "      <td>38.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>266.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>216.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       (0, 0)  (0, 3)  (0, 4)  (0, 5)  (1, 1)  (2, 2)  (2, 3)  (2, 4)  (3, 3)  \\\n",
       "test    213.0     0.0    19.0     9.0   182.0   199.0    12.0     7.0   217.0   \n",
       "train   214.0     1.0    19.0    10.0   182.0   198.0    12.0     7.0   216.0   \n",
       "\n",
       "       (3, 4)  (3, 5)  (4, 4)  (4, 5)  (5, 5)  \n",
       "test     38.0     3.0   267.0     1.0   215.0  \n",
       "train    38.0     3.0   266.0     0.0   216.0  "
      ]
     },
     "execution_count": 279,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame({\n",
    "    'train': Counter(str(combination) for row in get_combination_wise_output_matrix(y_train.toarray(), order=2) for combination in row),\n",
    "    'test' : Counter(str(combination) for row in get_combination_wise_output_matrix(y_test.toarray(), order=2) for combination in row)\n",
    "}).T.fillna(0.0)"
   ]
  },
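  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To compare the two divisions directly, the sketch below takes the (train, test) counts copied from the two tables above and reports the largest absolute train/test difference for any label combination. The helper `largest_gap` is hypothetical, and the numbers describe this particular run of the stratifier.\n",
    "\n",
    "```python\n",
    "# (train, test) counts per label combination, copied from the tables above\n",
    "original = {(0, 0): (227, 200), (0, 3): (0, 1), (0, 4): (21, 17), (0, 5): (12, 7),\n",
    "            (1, 1): (165, 199), (2, 2): (197, 200), (2, 3): (8, 16), (2, 4): (6, 8),\n",
    "            (3, 3): (196, 237), (3, 4): (27, 49), (3, 5): (1, 5), (4, 4): (277, 256),\n",
    "            (4, 5): (1, 0), (5, 5): (224, 207)}\n",
    "iterative = {(0, 0): (214, 213), (0, 3): (1, 0), (0, 4): (19, 19), (0, 5): (10, 9),\n",
    "             (1, 1): (182, 182), (2, 2): (198, 199), (2, 3): (12, 12), (2, 4): (7, 7),\n",
    "             (3, 3): (216, 217), (3, 4): (38, 38), (3, 5): (3, 3), (4, 4): (266, 267),\n",
    "             (4, 5): (0, 1), (5, 5): (216, 215)}\n",
    "\n",
    "\n",
    "def largest_gap(counts):\n",
    "    # largest absolute difference between train and test evidence over all combinations\n",
    "    return max(abs(train - test) for train, test in counts.values())\n",
    "\n",
    "\n",
    "print('largest gap, original split: %d' % largest_gap(original))  # 41\n",
    "print('largest gap, iterative split: %d' % largest_gap(iterative))  # 1\n",
    "```\n",
    "\n",
    "In the original split the single label 3, i.e. `(3, 3)`, is off by 41 examples, while in the iterative split no combination is off by more than one."
   ]
  },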
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the new division is much more balanced: for every label combination, the train and test counts differ by at most one."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
