{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Installing some necessary packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install ipywidgets\n",
    "!jupyter nbextension enable --py widgetsnbextension\n",
    "!jupyter labextension install @jupyter-widgets/jupyterlab-manager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install xgboost"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**It is necessary to change the working directory so the project structure works properly:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From this point, it's on you!\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from ml.data_source.spreadsheet import Spreadsheet\n",
    "from ml.preprocessing.preprocessing import Preprocessing\n",
    "from ml.model.trainer import TrainerSklearn\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.ensemble import RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = Spreadsheet().get_data('../../../data/raw/train.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['Survived', 'Pclass', 'Sex', 'Age'], dtype='object')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = Preprocessing()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cleaning data\n",
      "Category encoding\n"
     ]
    }
   ],
   "source": [
    "df = p.clean_data(df)\n",
    "df = p.categ_encoding(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived</th>\n",
       "      <th>Age</th>\n",
       "      <th>Pclass_1</th>\n",
       "      <th>Pclass_2</th>\n",
       "      <th>Pclass_3</th>\n",
       "      <th>Sex_female</th>\n",
       "      <th>Sex_male</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Survived   Age  Pclass_1  Pclass_2  Pclass_3  Sex_female  Sex_male\n",
       "0         0  22.0         0         0         1           0         1\n",
       "1         1  38.0         1         0         0           1         0\n",
       "2         1  26.0         0         0         1           1         0\n",
       "3         1  35.0         1         0         0           1         0\n",
       "4         0  35.0         0         0         1           0         1"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = df.drop(columns=[\"Survived\"])\n",
    "y = df[\"Survived\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((499, 6), (215, 6), (499,), (215,))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ensure the same random state passed to TrainerSkleran().train()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=123)\n",
    "X_train.shape, X_test.shape, y_train.shape, y_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf = TrainerSklearn().train(X, y, classification=True, \n",
    "                            algorithm=RandomForestClassifier, \n",
    "                            preprocessing=p,\n",
    "                           data_split=('train_test', {'test_size':.3}),\n",
    "                           random_state=123)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.8186046511627907,\n",
       " 'f1': 0.7607361963190185,\n",
       " 'precision': 0.7654320987654321,\n",
       " 'recall': 0.7560975609756098,\n",
       " 'roc_auc': 0.8644782688428387}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rf.get_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Age', 'Pclass_1', 'Pclass_2', 'Pclass_3', 'Sex_female', 'Sex_male']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rf.get_columns()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.        , 0.4920119 , 0.713     , 0.92      , 0.23      ,\n",
       "       0.00416667, 0.10548385, 0.        , 0.18766667, 1.        ,\n",
       "       0.03      , 0.        , 0.        , 0.97      , 0.85716667,\n",
       "       0.004     , 0.        , 0.49520155, 1.        , 0.00833333,\n",
       "       0.2827619 , 1.        , 0.34028571, 0.88516667, 0.47416667,\n",
       "       1.        , 0.01916342, 0.995     , 0.        , 0.28569833,\n",
       "       0.        , 0.        , 1.        , 0.28569833, 0.        ,\n",
       "       0.99      , 0.49520155, 0.        , 0.37355159, 0.        ,\n",
       "       0.03      , 0.        , 0.23487302, 0.03      , 1.        ,\n",
       "       0.42838095, 0.49935714, 1.        , 1.        , 0.27207937,\n",
       "       0.10548385, 0.98      , 0.00857143, 0.67      , 1.        ,\n",
       "       0.49520155, 1.        , 0.        , 0.        , 0.        ,\n",
       "       1.        , 1.        , 0.10548385, 0.        , 0.46297619,\n",
       "       0.        , 0.        , 1.        , 0.02      , 0.49520155,\n",
       "       0.10548385, 0.67      , 0.114     , 0.34859524, 0.551     ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.00833333,\n",
       "       0.37355159, 0.99      , 1.        , 0.09811905, 0.23487302,\n",
       "       1.        , 1.        , 1.        , 0.14030134, 0.0945303 ,\n",
       "       0.        , 0.713     , 0.        , 0.24386114, 0.98      ,\n",
       "       0.87      , 0.33694048, 1.        , 0.27242063, 0.98      ,\n",
       "       0.        , 1.        , 0.13428904, 0.22966667, 0.72997078,\n",
       "       0.03158333, 1.        , 0.        , 0.01572172, 0.        ,\n",
       "       0.        , 0.05366667, 1.        , 1.        , 0.87      ,\n",
       "       0.24386114, 0.46297619, 0.        , 0.        , 0.        ,\n",
       "       0.99      , 0.14030134, 0.44516667, 0.72997078, 0.        ,\n",
       "       0.        , 0.02      , 0.88516667, 0.28569833, 0.59893651,\n",
       "       0.70592316, 0.995     , 0.84156374, 0.25      , 0.        ,\n",
       "       0.        , 0.        , 0.00416667, 0.        , 0.09085714,\n",
       "       0.68168939, 0.02      , 1.        , 0.10548385, 0.        ,\n",
       "       1.        , 0.        , 0.13428904, 0.96      , 0.07      ,\n",
       "       0.99030303, 1.        , 0.15074706, 0.35530952, 0.        ,\n",
       "       0.15074706, 0.40907143, 0.24386114, 0.        , 0.        ,\n",
       "       0.02261905, 1.        , 0.03702789, 0.        , 0.        ,\n",
       "       0.59297547, 0.99      , 0.072     , 0.        , 0.02      ,\n",
       "       0.665     , 0.18766667, 1.        , 0.        , 0.        ,\n",
       "       1.        , 0.49190873, 0.995     , 1.        , 0.14030134,\n",
       "       0.98      , 0.        , 0.        , 0.995     , 1.        ,\n",
       "       1.        , 0.62      , 0.31      , 0.99      , 0.12246981,\n",
       "       0.        , 0.51754762, 0.        , 0.01      , 0.22966667,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.70592316,\n",
       "       1.        , 0.76      , 1.        , 0.62689683, 1.        ,\n",
       "       0.16      , 1.        , 0.14030134, 0.        , 0.63216667,\n",
       "       0.13866667, 1.        , 0.        , 0.31666667, 0.        ])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rf.predict_proba(X_test, binary=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicting new data\n",
    "def predict_new(X, model, probs=True):\n",
    "    X = p.clean_data(X)\n",
    "    X = p.categ_encoding(X)\n",
    "    \n",
    "    columns = model.get_columns()\n",
    "    for col in columns:\n",
    "        if col not in X.columns:\n",
    "            X[col] = 0\n",
    "    print(X)\n",
    "    if probs:\n",
    "        return model.predict_proba(X)\n",
    "    else:\n",
    "        return model.predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Pclass   Sex  Age\n",
       "0       3  male    4"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_data = pd.DataFrame({\n",
    "    'Pclass':3,\n",
    "    'Sex': 'male',\n",
    "    'Age':4\n",
    "}, index=[0])\n",
    "\n",
    "new_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cleaning data\n",
      "Category encoding\n",
      "   Age  Pclass_3  Sex_male  Pclass_1  Pclass_2  Sex_female\n",
      "0    4         1         1         0         0           0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[0.65140476, 0.34859524]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict_new(new_data, rf)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
