{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prepare some real-world data: download the data file mda.zip from IS (sources/mda.zip). The data come from an oncology experiment (breast cancer) in which tens of thousands of genes were profiled and a biomarker for \"pathologic complete response\" was sought. Some details at: https://doi.org/10.1186/bcr2468" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "Xtr = np.load('X-train.npy') # 22283 variables, 130 observations\n", "Ytr = np.load('Y-train.npy') # Ytr[:,0] - ER positive; Ytr[:,1] - pCR\n", "Ytr = Ytr.astype('int32') # make sure the labels are INTs\n", "\n", "Xts = np.load('X-test.npy') # 22283 variables, 100 observations\n", "Yts = np.load('Y-test.npy') # Yts[:,0] - ER positive; Yts[:,1] - pCR\n", "Yts = Yts.astype('int32') # make sure the labels are INTs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# AdaBoost and Random Forest classifiers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## AdaBoost\n", "Read the docs: http://scikit-learn.org/stable/modules/ensemble.html and have a look at the examples:\n", " * http://scikit-learn.org/stable/auto_examples/ensemble/plot_adaboost_hastie_10_2.html\n", " * http://scikit-learn.org/stable/auto_examples/ensemble/plot_adaboost_twoclass.html" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import AdaBoostClassifier\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.metrics import zero_one_loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Classical AdaBoost: discrete AdaBoost algorithm. 
Weak learner: decision stumps." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "T=200\n", "bdt = AdaBoostClassifier(DecisionTreeClassifier(max_depth=1), n_estimators=T,\n", " algorithm='SAMME')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AdaBoostClassifier(algorithm='SAMME',\n", " base_estimator=DecisionTreeClassifier(class_weight=None,\n", " criterion='gini',\n", " max_depth=1,\n", " max_features=None,\n", " max_leaf_nodes=None,\n", " min_impurity_decrease=0.0,\n", " min_impurity_split=None,\n", " min_samples_leaf=1,\n", " min_samples_split=2,\n", " min_weight_fraction_leaf=0.0,\n", " presort=False,\n", " random_state=None,\n", " splitter='best'),\n", " learning_rate=1.0, n_estimators=200, random_state=None)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# fit the model (as usual)\n", "bdt.fit(Xtr, Ytr[:,0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result of AdaBoost with decision stumps can be analyzed to find the most important variables from the data set. 
The following command returns the indices of the variables whose importance score exceeds a threshold (0.01):" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([ 1, 1262, 1952, 2320, 2584, 5830, 5921, 6915, 7681,\n", " 7746, 8076, 9810, 10643, 12014, 12172, 12820, 12893, 13199,\n", " 14522, 16159, 19896, 22097]),)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.where(bdt.feature_importances_ > 0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compute the train and test errors after each boosting step:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# - train error (zero_one_loss expects y_true first, then y_pred)\n", "err_tr = np.zeros((T,))\n", "for i, yp in enumerate(bdt.staged_predict(Xtr)):\n", "    err_tr[i] = zero_one_loss(Ytr[:,0], yp)\n", "\n", "# - test error\n", "err_ts = np.zeros((T,))\n", "for i, yp in enumerate(bdt.staged_predict(Xts)):\n", "    err_ts[i] = zero_one_loss(Yts[:,0], yp)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAcFklEQVR4nO3df5RUdf3H8eebRVB+iHxzQwQELBSRDHEljt+S/FGhIejXH0H+TM0fhV81yzCL+ELfjopaeiKVxBSTH6amlHj8ZmoeLdAFQUVECdDlh7CSwK4oCLy/f3xmnNlhl53Zmd179/p6nDNn5t65c+e9d2Zf85nP/dw75u6IiEhytYm6ABERaV4KehGRhFPQi4gknIJeRCThFPQiIgmnoBcRSbi8gt7MhpvZMjNbbmbj6rn/AjOrNrNFqcvFpS9VRESaom1jC5hZGTAF+BqwGnjJzOa4++s5i85297HNUKOIiBQhnxb9EGC5u69w9+3ALGBU85YlIiKl0miLHugBVGVNrwa+VM9yp5vZscCbwNXuXpW7gJldAlwC0LFjx6P69+9feMUiIp9iCxYseM/dywt5TD5Bn48/AzPdfZuZXQrcBxyfu5C7TwWmAlRUVHhlZWWJnl5E5NPBzN4u9DH5dN2sAXplTfdMzfuEu290922pybuBowotREREmkc+Qf8S0M/M+ppZO2A0MCd7ATPrnjU5ElhauhJFRKQYjXbduPsOMxsLPAmUAfe4+xIzmwhUuvsc4L/NbCSwA/g3cEEz1iwiIgWwqE5TrD56EZHCmdkCd68o5DE6MlZEJOEU9CIiCaegFxFJOAW9iEjCKehFRBJOQS8iknAKehGRhFPQi4gknIJeRCThFPQiIgmnoBcRSTgFvYhIwinoRUQSTkEvIpJwCnoRkYRT0IuIJJyCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScAp6EZGEU9CLiCScgl5EJOEU9CIiCaegFxFJOAW9iEjCKehFRBJOQS8iknAKehGRhFPQi4gknIJeRCThFPQiIgmXV9Cb2XAzW2Zmy81s3B6WO93M3MwqSleiiIgUo9GgN7MyYApwEjAAGGNmA+pZrjNwJTC/1EWKiEjT5dOiHwIsd/cV7r4dmAWMqme5ScCNwEclrE9ERIqUT9D3AKqyplen5n3CzAYDvdz98T2tyMwuMbNKM6usrq4uuFgRESlc0TtjzawNcCtwTWPLuvtUd69w94ry8vJin1pERPKQT9CvAXplTfdMzUvrDAwEnjWzVcBQYI52yIqIxEM+Qf8S0M/M+ppZO2A0MCd9p7tvdvf93b2Pu/cB5gEj3b2yWSoWEZGCNBr07r4DGAs8CSwFHnT3JWY20cxGNneBIiJSnLb5LOTuc4G5OfPGN7DsV4svS0RESkVHxoqIJJyCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScAp6EZGEU9CLiCScgl5EJOEU9CIiCaegFxFJOAW9iEjCKehFRBJOQS8iknAKehGRhFPQi4gknIJeRCThFPQiIgmnoBcRSTgFvYhIwinoRUQSTkEvIpJwCnoRkYRT0IuIJJyCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScAp6EZGEU9CLiCRcXkFvZsPNbJmZLTezcfXcf5mZvWpmi8zseTMbUPpSRUSkKRoNejMrA6YAJwEDgDH1BPkMd/+Cuw8CbgJuLXmlIiLSJPm06IcAy919hbtvB2YBo7IXcPctWZMdAS9diSIiUoy2eSzTA6jKml4NfCl3ITP7PvADoB1wfEmqExGRopVsZ6y7T3H3zwE/Bn5a3zJmdomZVZpZZXV1dameWkRE9iCfoF8D9Mqa7pma15BZwKn13eHuU929wt0rysvL869SRESaLJ+gfwnoZ2Z9zawdMBqYk72AmfXLmvwm8FbpShQRkWI02kfv7jvMbCzwJFAG3OPuS8xsIlD
p7nOAsWZ2IvAx8D5wfnMWLSIi+ctnZyzuPheYmzNvfNbtK0tcl4iIlIiOjBURSTgFvYhIwinoRUQSTkEvIpJwCnoRkYRT0IuIJJyCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScAp6EZGEU9CLiCScgl5EJOEU9CIiCaegFxFJOAW9iEjCKehFRBKudQT99OlQXR11FSIirVL8g766Gs4/Hx54IOpKRERapfgH/ZYt4XrTpmjrEBFppeIf9DU14Tod+CIiUpD4B31tbbhW0IuINEn8g14tehGRosQ/6NWiFxEpSusJ+s2bo61DRKSVin/Qq+tGRKQo8Q96dd2IiBQl/kGvFr2ISFHiH/TpFn1NDezaFW0tIiKtUPyDPt2id8+EvoiI5C3+QZ8d7uq+EREpWPyDPt2iBwW9iEgT5BX0ZjbczJaZ2XIzG1fP/T8ws9fN7BUz+5uZ9S5Zhdkteo2lFxEpWKNBb2ZlwBTgJGAAMMbMBuQs9jJQ4e5HAA8BN5Wswtpa+Oxnw2216EVECpZPi34IsNzdV7j7dmAWMCp7AXd/xt23pibnAT1LVmFNDRx4YLitoBcRKVg+Qd8DqMqaXp2a15CLgCfqu8PMLjGzSjOrrM73F6Nqa6FH6ukU9CIiBSvpzlgzOweoACbXd7+7T3X3CnevKC8vz2+lNTUKehGRIrTNY5k1QK+s6Z6peXWY2YnA9cAwd99Wkup27ICPPoLu3cO0dsaKiBQsnxb9S0A/M+trZu2A0cCc7AXM7EjgLmCku28oWXXpETddukCnTqFFv3mzWvYiIgVoNOjdfQcwFngSWAo86O5LzGyimY1MLTYZ6AT80cwWmdmcBlZXmHTQd+4M++4bAv7MM+Hcc0uyehGRT4N8um5w97nA3Jx547Nun1jiuoL0wVKdOoWgf/99eOGFzHBLERFpVLyPjE236Dt1Ct03lZWwdSu88w58/HG0tYmItBKtI+jTXTfvvBOmd+2CqqqGHyciIp+Id9Dndt1kW7my5eupz+uvQ//+8alHRCRHvIM+t0UP0K1buI5LsD7/PCxbBnfdFXUlIiL1infQ19eiHzECysriE/TpOu69V/sNRCSW4h302S36Ll3C7SFDoHdvWLEiurqyrVwJZrB+PTz+eNTViIjsJt5Bn27Rd+iQadEPHgx9+8arRT9sWDh69+qr4dRTYf78qKsSEflEvIO+tjaEfFkZnHgijBkDRxwRv6Dv1w8mTYKuXeEvf4EZM6KuSkTkE/EO+pqa0G0D8MUvhgBt1y4E/YYN8MEH0dZXWwvV1aGeiy6ChQvhc5+DdeuirUtEJEu8g762NuyIzdW3b7hetapFy9lN+vnT9UDowlHQi0iMxDvos1v02Q4+OFxPmQIPP9yyNWVLdx+l6wEFvYjETnyD3j0cjHTAAbvfd+ih4QPgjjvgjDPglVdavj7IjPzJbdG/+2409YiI1CO+Qf/cc/Cvf8G3v737ffvtF4YzrlgR+uynTWv5+iC06Dt2hP33z8zr3j3sO0iPGBIRiVh8g37atDCk8vTT679/n31CS/q00+D++8MPlLS0lStDDWaZeekfSVH3jYjERDyDftMm+OMf4eyzw/DKPbn44nD64j/9qe78Dz7IhP/OnWGZtHfegSVLMpdtefwg1pYtmSNft28Pj3vzzbrdNpDpalLQi0hMxDPo//znENLf+U7jyx5/PPTps3v3zTe+ken2+cUvwg7TTZtgwYKw/MCBmct3v7vn59i2DQ47DK65Jkxffnl43BtvwCGH1F1WLXoRiZl4Bv2yZeEgqSOPbHzZNm3gwgvhb3/L7Bz96KNwdOpjj4UhkHfdFUJ+5kyYOjV0+8yaBQ8+GLqGZs+GjRsbfo5HH4W1a8P5bNauDeP5Tz0VHnoIrr++7rIKehGJmXgG/YoV0KsXtM3rB7DgggtCP/nvfx+mX3st/LD4rl1wzjkhdDt0gN/+NoT9mWfCt74VrsePD10xDzzQ8PqnTQu
Pr6mBs84KHyQ//Wn4kOjate6yXbtC+/YKehGJjXgG/cqVdcemN6ZXLxg+PAT9zp3hCFUI3SovvBBObfyLX4QPgJqa0K+fdsQRUFEBd98dhnTmWrUKnnoKfvSjzPoGDQrn3KmPWeinV9CLSEzEN+hzd3I25qKLYM0aePJJePnlMATz5z8P911wQbi0bx/G4P/nf9Z97MUXw6uvhm6X3MugQWGZCy/MfEBcdFHdkTa5muugqXHjwjl1REQKkGffSAvaujWMkS806E85BcrLQ8t8zZrQv3/GGWEs/mWXhS6Ve+8NIZwb0ueeC2+91fDY90GD4KCD4NJLQ32N7STu3j3sZygl98yQ05/9rLTrFpFEi1/Q13f+mHy0awfnnQe33RZ20F5xRZiXHYqjR9f/2A4d4OabG3+OfffNfEvYkwMOgGefzavsvK1eDe+9Fy7vv7/7vgERkQbEr+smff6YQoMeQpfKjh1h52o+I3aaS/fuIYzzGZ+fr/R+B4BFi0q3XhFJvGhb9OecU3f6a1/LdJ80JegPOwyOOQb+8Y+Gd5a2hPQQy3ffDb+GlW3OnDCsc0+6dYMbb6w76ujll0OXk3sI/eOOK23NxXKHG26Ak08Op5RuzKpVoZttwoT8R1fla+tW+J//gR/+MHTn3XVX2JFeqm12ww1h5396/03aO++Ev2f79jBdVgbXXguHH57/uu+/P+xn2muv8O2xT5/CaquthYkT4cc/hs98Jow0GzgQjj22sPVIacyYAXPnFr+eq6+Go45q8sOjC3p3mDcvM71pUxivfvbZYZx7+kfACzV+fDirZe6BTC0p/SH15pt1g37nTvje98I/Y/b5cbJt3w5VVSGURozIzF+4MHyQbd4cQj9uFiyAn/wk/Fh6Pj+p+Mtfwu9+B0cfDaNGlbaWP/wBbropdN1dfjl8//th273yyp53oudj4UK47jr4+9/hiSfq3nfLLTB9eiacq6rCUNzZs/Nbd21teH+0axde544d4Te/Kay+++6DyZPD6b2/8x0YOzaMLEs3FKTlbN0aXs+ysuK7WrOP7G8Kd4/kctRRR3kd//ynO7i3b+8+YIC3ahs3hr/lxhvrzn/88TD/4Ycbfuz27e7durmPGlV3fo8e7uec437KKe6HHVb6mot16aXhb2vTxr2qas/L1tS4d+oUlh8xovS1HH10WHePHu6TJoXb4D5vXvHrvvzysC4z97ffzsz/8EP3rl3dv/WtzLyrrnLfay/36ur81j1tWlj388+7f/vb7l26uG/dWlh9gwaFdRx0kPvPf5752ysrC1uPFG/69LDtn3mmpKsFKr3AvI1P0O/aFQIe3L/5zVJul2j07l33n97d/bTT3MvL3bdt2/Njr73WvazMfe3aML1+fdgut9wS/nnN3Gtrm6Pqpqmtde/c2f2rXw11Tpq05+XTgXbcceGDYc2a0tWyeHFm3eC+997uQ4a4d+jgfvHFxa37gw/c993Xfdiw8BpMmJC5b+bM8Hx//Wtm3quvhnm33prf+o85xr1///C/8PTT4bF/+EP+9S1YsPvfPnRouL7ssvzXI6UxbJj75z8fXs8SakrQx2fUjVnYmXrNNU3rn4+bwYPrdrGsXx/O4XPVVeGr+Z5ceGHoevjJT8IPj6eHag4eHE6u5g6LF4f9EUuXhq6u3L5c93Cg13HH7d4H/t57Yaf30Ufv/tzz5oVz+BRi0aKwb2XixNBHfffd0LNnw8vffjv075/pO7/uutL1nz/2WNi+M2aELovq6jAC6+mnw2kvjjmm6V0YixeH7T9hAvzv/4bhrumuuTvuCK/B8cdnlh84EIYMgTvvbPyr+5YtYd/S5MmhvmHDwkGDv/pV5mR6jXnkEdh773D094AB8O9/w5VXhq60GTNCLeq+aRm1taF775e/jMc2L/SToVSX3Vr07u4bNrjvt5/7739fos++CE2cGFpVmzeH6ZtuCtNLl+b3+BNOyHztBveOHd03bQqtfAjr//hj9wMPdK9
vWz7xRFju9tt3v2/MmNClsH593fnvv+++zz51nzffy8CBoeXy0EP5LZ+u6+tfb9rz7ely7rlh3T/7mfsBB4Tuj/nzQyu82HUPGBD+zkce2f2+G27YfVvfe2/+6+7Yse5rMnly4fVdeGF47HXXha6rDz90f+GF0m9jXRq/7LNPab+tptCEFr2Fx7W8iooKr6ys3P2ODz8MrZI4fAoW4/HHw87U556DL3857Azcf/+wszIf27bVPbp2v/3CBeCEE8L5gG67LbMjc+HCukNKzzgj/MziF74QWqLp7blxIxx4YNjpe/PNmTNyQmiVfu97YSdj//6F/b2f/WzmlNJr12ZGntSnrCy0+M3CcmvXFvZcjenZM3yL2bUr7AxN17VhQ9hBVozsv3PduswQ2jZtwqk4ct+37uHvy6dV3qVL3Za/ezh+YufO/OtL/+07d4Ztu88+YX4p/nYpTO7rWSJmtsDdKwp6TOyCPinWrQuB+utfh2FRX/kK3HNPfqdebszMmeEUzL17h3/eLVvC6RnSIzSqq6FHj3DgVlUVvPhippvm9tvD1/levcLIjCVLMuF01FEhIDRCQyS2mhL08TtgKim6dw9B+49/hJZy587hbJmlcNppoaXw9tvhg+P008PZN1esCMF+552hBTl7dmjR/fa3YX5VVeg/r6gIY7SXLg1jfKuqwmmeFy5s/Dw+ItL6FNrXU6pLvX30STNiRKa/7rvfLe26r7girPeNNzIjNLIvQ4eG5c4/f/f77rij7hDH9KV9+zA0VERii0T00SfJypVhtEebNjByZDhSsVS2bAldLMOGhZh+9NEwyiLtuOPCqI3qavjLX0J/NYT9H2edFY68nDcvdN2kHXpo2J8gIrHVbH30ZjYcuA0oA+529xty7j8W+DVwBDDa3R9qbJ2fiqAXESmxZumjN7MyYApwEjAAGGNmA3IWewe4AJhRyJOLiEjzy+eAqSHAcndfAWBms4BRwOvpBdx9Veq+Xc1Qo4iIFCGfUTc9gKqs6dWpeQUzs0vMrNLMKqurq5uyChERKVCLDq9096nuXuHuFeXl5S351CIin1r5BP0aoFfWdM/UPBERaQXyCfqXgH5m1tfM2gGjgTnNW5aIiJRKo0Hv7juAscCTwFLgQXdfYmYTzWwkgJkdbWargTOBu8xsScNrFBGRlpTXaYrdfS4wN2fe+KzbLxG6dEREJGZ0rhsRkYRT0IuIJJyCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScAp6EZGEU9CLiCScgl5EJOEU9CIiCaegFxFJOAW9iEjCKehFRBJOQS8iknCRBb079OwZLoMGwZYtUVUiIpJskQW9GQwfDl/8IixeDEv044MiIs0i0q6bu++Gm28Ot1eujLISEZHkiryPvnfvcL1qVaRliIgkVuRB36EDdOumFr2ISHOJPOgB+vRRi15EpLko6EVEEi4WQd+3L7z9NuzcGXUlIiLJE4ug79MHPv4Y1q2LuhIRkeSJRdD37RuuV64MB1KJiEjpxCLo+/QJ1y++CAceCDNnRlqOiEiixCLoDzooXE+aBO++C888E209IiJJEoug33vv0JLfvDlM63QIIiKlE4ugh9B9U1YGJ54Ygl599SIipdE26gLSrrgCzjoL2raFp56CtWuhR4+oqxIRaf1iE/SjR4frZ58N10uWKOhFREohNl03aYcfHq7VTy8iUhp5Bb2ZDTezZWa23MzG1XN/ezObnbp/vpn1aWpB5eXhoqAXESmNRoPezMqAKcBJwABgjJkNyFnsIuB9d/888CvgxmKKOvxweO21YtYgIiJp+fTRDwGWu/sKADObBYwCXs9aZhQwIXX7IeA3ZmbuTRs7c/jhMH06zJ/flEeLiCTLIYdA165Nf3w+Qd8DqMqaXg18qaFl3H2HmW0GPgO815SijjwSpkyBoUOb8mgRkWR59FEYNarpj2/RUTdmdglwCcBB6cNh63HeeWFc/fbtLVSYiEiMDR5c3OP
zCfo1QK+s6Z6pefUts9rM2gJdgI25K3L3qcBUgIqKiga7dfbaC044IY/KRESkUfmMunkJ6Gdmfc2sHTAamJOzzBzg/NTtM4Cnm9o/LyIipdVoiz7V5z4WeBIoA+5x9yVmNhGodPc5wDTgfjNbDvyb8GEgIiIxYFE1vM2sBlgWyZMXZn+auFO5BanG0mkNdbaGGqF11NkaaoS6dfZ29/JCHhzlKRCWuXtFhM+fFzOrjHudqrF0WkOdraFGaB11toYaofg6Y3cKBBERKS0FvYhIwkUZ9FMjfO5CtIY6VWPptIY6W0ON0DrqbA01QpF1RrYzVkREWoa6bkREEk5BLyKScJEEfWPnt4+CmfUys2fM7HUzW2JmV6bmTzCzNWa2KHU5OeI6V5nZq6laKlPz/sPM/mpmb6WuizjPXUlqPDRrey0ysy1mdlUctqWZ3WNmG8zstax59W4/C25PvU9fMbMizzhSVI2TzeyNVB1/MrP9UvP7mNmHWdv0zghrbPD1NbPrUttxmZl9oyVq3EOds7NqXGVmi1Lzo9qWDWVP6d6X7t6iF8LRtf8CDgbaAYuBAS1dRz11dQcGp253Bt4knH9/AvDDqOvLqnMVsH/OvJuAcanb44Abo64z5/V+F+gdh20JHAsMBl5rbPsBJwNPAAYMBeZHWOPXgbap2zdm1dgne7mIt2O9r2/q/2gx0B7om/r/L4uqzpz7bwHGR7wtG8qekr0vo2jRf3J+e3ffDqTPbx8pd1/n7gtTt2uApYTTL7cGo4D7UrfvA06NsJZcJwD/cve3oy4EwN2fI5ymI1tD228UMN2DecB+ZtY9ihrd/f/cfUdqch7h5IKRaWA7NmQUMMvdt7n7SmA5IQea3Z7qNDMDzgJmtkQtDdlD9pTsfRlF0Nd3fvtYBaqFn0I8Ekj/9MnY1Feke6LuFgEc+D8zW2DhtM8A3dx9Xer2u0C3aEqr12jq/iPFaVumNbT94vpevZDQokvra2Yvm9nfzewrURWVUt/rG9ft+BVgvbu/lTUv0m2Zkz0le19qZ2wOM+sEPAxc5e5bgDuAzwGDgHWEr3pR+rK7Dyb8tOP3zezY7Ds9fLeLxZhZC2c7HQn8MTUrbttyN3HafvUxs+uBHcADqVnrgIPc/UjgB8AMM9s3ovJi//rmGEPdRkik27Ke7PlEse/LKII+n/PbR8LM9iJs6Afc/REAd1/v7jvdfRfwO1roK2dD3H1N6noD8KdUPevTX91S1xuiq7COk4CF7r4e4rctszS0/WL1XjWzC4ARwNmpf3xS3SEbU7cXEPq/D4mivj28vrHajgAWfjfjv4DZ6XlRbsv6socSvi+jCPp8zm/f4lL9ddOApe5+a9b87L6v04DIfrbczDqaWef0bcIOuteo+3sA5wOPRVPhbuq0mOK0LXM0tP3mAOelRjkMBTZnfZVuUWY2HLgWGOnuW7Pml5tZWer2wUA/YEVENTb0+s4BRptZezPrS6jxxZauL8eJwBvuvjo9I6pt2VD2UMr3ZUvvYc7aa/wm4RPz+ihqqKemLxO+Gr0CLEpdTgbuB15NzZ8DdI+wxoMJoxcWA0vS247w+7x/A94CngL+IwbbsyPhV8a6ZM2LfFsSPnjWAR8T+jYvamj7EUY1TEm9T18FKiKscTmhXzb93rwztezpqffCImAhcEqENTb4+gLXp7bjMuCkKF/v1Px7gctylo1qWzaUPSV7X+oUCCIiCaedsSIiCaegFxFJOAW9iEjCKehFRBJOQS8iknAKehGRhFPQi4gk3P8DFWU4QBJby70AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig = plt.figure()\n", "ax = plt.subplot(111)\n", "ax.set_ylim(-0.01, 0.5)\n", "ax.set_xlim(0, T+2)\n", "ax.plot(np.arange(T)+1, err_tr, color='blue')\n", "ax.plot(np.arange(T)+1, err_ts, color='red')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### TO DO:\n", "Try other base learners: e.g. a decision tree with 2 levels, in the above example." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## RealAdaboost" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAft0lEQVR4nO3df5xUdb3H8dcHEFFARSRC+bUqpZip617x5o8UsYQUSrRAsiwIuw8xf9QtTCSlfFx/ZffhFTMCS32o+KMSSrxaYKR5RX6KIJIgGCAqJOIvBNb93D8+M87s7uzu7O7szuzx/Xw85rFzzpwz5zNnzr7nzPd8zxlzd0REJLnaFbsAERFpWQp6EZGEU9CLiCScgl5EJOEU9CIiCaegFxFJuLyC3szOMLPVZrbGzCbmePwCM9tiZstSt3GFL1VERJqiQ0MTmFl7YCpwOrARWGhms939hRqT3u/uE1qgRhERaYZ89uiPA9a4+8vuvguYCYxo2bJERKRQGtyjBw4CNmQNbwQG5ZhupJmdDPwDuMzdN9ScwMzGA+MBOnfufOxhhx3W+IpFRD7GFi9evNXdezRmnnyCPh9/BO5z951mdiFwJzC45kTuPg2YBlBRUeGLFi0q0OJFRD4ezOyVxs6TT9PNJqBP1nDv1LiPuPu/3H1nanA6cGxjCxERkZaRT9AvBAaYWZmZdQRGAbOzJzCzXlmDw4FVhStRRESao8GmG3evNLMJwGNAe+AOd19pZlOARe4+G/iemQ0HKoE3gQtasGYREWkEK9ZlitVGLyLSeGa22N0rGjOPzowVEUk4Bb2ISMIp6EVEEk5BLyKScAp6EZGEU9CLiCScgl5EJOEU9CIiCaegFxFJOAW9iEjCKehFRBJOQS8iknAKehGRhFPQi4gknIJeRCThFPQiIgmnoBcRSTgFvYhIwinoRUQSTkEvIpJwCnoRkYRT0IuIJJyCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScAp6EZGEU9CLiCScgl5EJOEU9CIiCaegFxFJOAW9iEjC5RX0ZnaGma02szVmNrGe6UaamZtZReFKFBGR5mgw6M2sPTAVGAoMBEab2cAc03UFLgEWFLpIERFpunz26I8D1rj7y+6+C5gJjMgx3U+B64EPClifiIg0Uz5BfxCwIWt4Y2rcR8ysHOjj7o/U90RmNt7MFpnZoi1btjS6WBERabxmH4w1s3bAzcD3G5rW3ae5e4W7V/To0aO5ixYRkTzkE/SbgD5Zw71T49K6Ap8B/mpm64Hjgdk6ICsiUhryCfqFwAAzKzOzjsAoYHb6Q
Xff7u4HuHt/d+8PPAMMd/dFLVKxiIg0SoNB7+6VwATgMWAV8IC7rzSzKWY2vKULFBGR5umQz0TuPgeYU2Pc5DqmPaX5ZYmISKHozFgRkYRT0IuIJJyCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScAp6EZGEU9CLiCScgl5EJOEU9CIiCaegFxFJOAW9iEjCKehFRBJOQS8iknAKehGRhFPQi4gknIJeRCThFPQiIgmnoBcRSTgFvYhIwinoRUQSTkEvIpJwCnoRkYRT0IuIJJyCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScAp6EZGEyyvozewMM1ttZmvMbGKOx79rZs+b2TIze8rMBha+VBERaYoGg97M2gNTgaHAQGB0jiC/192PdPejgRuAmwteqYiINEk+e/THAWvc/WV33wXMBEZkT+Dub2cNdga8cCWKiEhzdMhjmoOADVnDG4FBNScys4uAy4GOwOCCVCciIs1WsIOx7j7V3Q8BfgRMyjWNmY03s0VmtmjLli2FWrSIiNQjn6DfBPTJGu6dGleXmcCXcz3g7tPcvcLdK3r06JF/lSIi0mT5BP1CYICZlZlZR2AUMDt7AjMbkDX4JeClwpUoIiLN0WAbvbtXmtkE4DGgPXCHu680synAInefDUwwsyHAbmAb8M2WLFpERPKXz8FY3H0OMKfGuMlZ9y8pcF0iIlIgOjNWRCThFPQiIgmnoBcRSTgFvYhIwinoRUQSTkEvIpJwCnoRkYRT0IuIJJyCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScAp6EZGEU9CLiCScgl5EJOEU9CIiCaegFxFJOAW9iEjCKehFRBJOQS8iknAKehGRhFPQi4gknIJeRCThFPQiIgmnoBcRSTgFvYhIwinoRUQSTkEvIpJwCnoRkYRT0IuIJFzbCfoXXoALL4TKymJXIiLSprSdoL/7bpg2DdasKXYlIiJtSl5Bb2ZnmNlqM1tjZhNzPH65mb1gZsvNbK6Z9St4pUuWxN916wr+1CIiSdZg0JtZe2AqMBQYCIw2s4E1JlsKVLj7Z4GHgBsKWqW7gl5EpIny2aM/Dljj7i+7+y5gJjAiewJ3f8Ld308NPgP0LmiVGzfC1q1xX0EvItIo+QT9QcCGrOGNqXF1GQs8musBMxtvZovMbNGWLVvyrzK9N9+unYJeRKSRCnow1sy+DlQAN+Z63N2nuXuFu1f06NEj/ydeuhTM4KST4OWXC1OsiMjHRD5BvwnokzXcOzWuGjMbAlwJDHf3nYUpL2XJEjjsMDjiCO3Ri4g0Uj5BvxAYYGZlZtYRGAXMzp7AzI4BfkWE/BsFr3LJEigvh7IyeOutuImISF4aDHp3rwQmAI8Bq4AH3H2lmU0xs+GpyW4EugAPmtkyM5tdx9M13rp1sGkTHHtsBH16nIiI5KVDPhO5+xxgTo1xk7PuDylwXRm/+U20z59zTvWeN8cc02KLFBFJkryCvmg+/DCC/otfhD59oGvXGK8DsiIieSvtSyA8/nj0oR83Lob32y9uaroREclbaQf9XXdBjx5w1lmZcQcfDLffHoE/d27xahMRaSNKO+iXL4cTToCOHTPjbrgBLr8cPvgA5sype14REQFKOejdYf36TE+btNNOgxtvhKOOypwxKyIidSrdoH/jDXj//WiqyaW8PM6YdW/dukRE2pjSDfr0Adeae/RpxxwD27fHdDNmwIMPtl5tIiJtSNsN+vLy+PvEE3DxxXDtta1Tl4hIG1O6QZ/uK9+/f+7HP/MZ6NABrr4aduyAlSthZ2EvsSMikgSlG/Tr1kHPnrD33rkf79QpLnK2cSO0bx+/JbtiRevWKCLSBpR20NfVbJOWvgzCxRfHX/XCERGppbSDvq4eN2lf+EKcUDVpEuy7r4JeRCSH0gz6ykr45z8b3qMfPRpeew26d4+Dswp6EZFaSjPoN2yIC5o1FPQQPy8I0YyzfHl8S
IiIyEdKL+hffx3mz4/7+QR9Wnl5XBbhxRdj+N13Y1hE5GOutC5T/M470S7//vsxPGBA/vMee2z8feqp6I1z0knRNfMPfyh4mSIibUlpBf2yZRHyV10V17Tp06fhedI+/WkYOBB++1v47GfjuZYvj+6XvXu3WMkiIqWutJpu0gdT/+M/4POfb9y8ZnHd+gUL4D//E/baC6qqIvhFRD7GSi/oP/lJ6NWrafOffz7ssQc8/TSMGRPfCmbMiMAXEfmYKr2gT1/DpikOOAC+/OW4P24cjB0blzr+298KUl6L2rkTBg/OHIgWaY4ZM+DAA2OnqVevOPaV7qjQkiZNgksuiftz50bza7qG+m7HHVd354nNm6GiInf36Zkz4cwztTPXgNJpo9+xA1atygR1U117LRx/fGw4AwfGuL//HU45pdkltqjly+MCbYMGNb7ZSiRbVRX8139Bly5w6qkx7s474bbb4JZbWm6527bBTTfB7t3wwx/G/Z074StfqX++7dvh/vuj48To0bUfnzEDFi+GX/wC7r67+mP33w+PPALz5sGQIYV7LUnj7kW5HXvssV7NggXu4P6733lBDRjgfvbZhX3OlnD77fH6v/a1Ylcibd28ebEt3X13ZtyoUe7durnv2NFyy/2f/4nlgvuFF7qbuV91VcPzffihe1mZ++DBuR/r3z+es1Mn923bqj/et288NmpUYV5DGwAs8kbmbek03aS/ljWn6SaXtnLGbLpG/fC5NNeMGXFJkJEjM+PGjYs97pbqbuwO06fHiYunnAK/+lWM//a3G563XbuYbt68zFVr0554Ippfv//9aNq5997MY1u3xhn0++wDv/89/OtfhXo1iVPcppuvfz1zf8kS6NYN+vUr7DLKy+Pr3Ztvwv77RxPRlCnwgx/EpRNymTYNDj88+uL//e9RW/rCaTVt3Ro/bTh5MnTuDD/9KaxeXX2adu3id26PPrr6+E2b4ofOJ0+OX8sCBX1rmTo1uuGedFLLLmfjxgi9yZNjO7jiCnj11eh0cP31EZA/+lGcKFifr34Vhg+PJr4//SmeZ8eO6GG2fXv1aR96KI5P7bVXZtypp8YJiDNm5G4eqcvmzdHu3tAlwHfuhOeeg1tvjQ+Zv/41OkPUdZnxmi64AH7yEzjvPDj00Mz4pUthv/3i/+ovf4kmqaefjo4X6bPir746/r/uuQe+973MvDffXNo7eXvsEbX36xfvyxNP1D3tZZdlzhVqguIFvTs880z1cePGRTfJQkp/Q1i6NDa8e+6B666LDfDCC2tPv3t3hPoRR8RGctllsHBhHPDJdaburbfGD5aXlUX7+uTJcXAp+/LKmzbBW2/B7NnV5737bvjZz2JZy5fDnnvCli1xVm+XLgVbBVLD+vXxHpeXw6JFLbusu+6K97i8PLaJG2+My2+//nocfK+sjEDq2zf+8XPZujU6FHzpS/DjH0eb9OmnR7Dedltse+2yvpwfemjtHZP0XvNVV8Vec0MXDEy75ZboopzPWeqDBkVvtz33jG174sT8lgFxrsuECfHatm6t/tiVV8aH1qRJ8fpnz47/yW99Kx7/5jfhvvviG8XFF0eGrFsX3wJ69izd/6X166FrV7jmmnjtXbrEh2Qu27Y1b1mNbesp1K1WG31L2bo12vBuuCGGBw2K4fHjc0+/bFmmnfE3v8ncnzSp9rSVle59+sTjxx7rftFF7nvu6f7mm9Wnu+IK93bt3DdurD7+3HNj3rKy+Dt8ePxdvrzZL1vqMXly5n1dsqRllzVyZCxn2DD3c85xP+AA93ffdf/EJ9y/8hX3s85y79XLfffuup/jwQfjOaZPj+0ovf0OGuQ+cKB7VVV+tWzYEPPn2pZz2bXL/ZOfjO2ylKT/L8vKov3ePXOM69lnY3jSpDhG8M9/Fq3MBqWPm9x0U9S+eHFes9GENvrkB717HLAZPdr9+efjJZu5V1TknvaOOzLTdOrk3rGj++c+537QQRHs2R59NKY99dTMw
aLzzqv9nC+9FI9fe2318YccEstJh86dd8bfWbMK87qltvSH87//e3woX3RRyy6vrCze43bt3PfYw/3yy2P8D37g3qFDjL/iivqfY+fO+IDo1Cm2j1NOydy/+ebG1TNsWO5tOZeHHy7N7fHdd927do3a0h0ttm9333vv+ACsrHTv3dt96NDi1tmQv/wlkxvHHJP3bE0J+tLpXtmSysvjGjiTJ8fX4zFj4qve7t21vy4vXRpfoYYPjwM/X/sanHsunHNOtKcdckhm2jvvjL77990X7WwffBBtozUdemh0mfz1r6Nvc3l5TL92LXznO/GVc++94YtfjOnXrYs+zzWbtmrq0yeaoyAOZJ10Uub1uMPDD0f7bVlZ47psVlVFe+iQIdWbBJrr5Zdh1y447DB47704MFhZGV/5Dz88Ln+xcGHhupe+8QY8+misi7R16+LqqD//eayfe+6Jts/DD49uuWmbNsGf/5wZ7toVzj67/qbFXbvgyScz78m2bbG873wn3vuqqsz2MXZsdD+Ehg9YduwI3/hGNPEMHhzb4cknx3t9/vl5r46PljtyZO1tOZfp0+NYwrBhjVtGS+vcOY4zTJuWaZrdZ5/4P73vPvjEJ+LYyH//d3HrbEj6uMm6dblzo5Aa+8lQqFur7tHffHNmr3nMGPd77437zz1Xe9oTTnA/8UT3p5+OPa7582OPqnfvzHNk3yZOjPnGjnU/4ojoDpbL/fdn5unVK/Np/uij7iNGxJ5WVZV7ly7uF18c3UJzLS/7Zua+ZUs09WQ3T7lnvm2Ae/v2tZuN6nPXXTHfffflP09DqqrcP/vZ2JuurHT/yU8y9R1ySKy3iRNjeNmywixzzJjc6+3AA90/+MD9yScz4/be2/2ttzLzjh5de74//rH+5V13XUw3f34Mp7s5PvaY+5ln1u4+OGRI/nudq1bFt8vf/z7W5dFHu3/jG/mvi7T6tuVct8mTG7+M1rBkSXxDeuqpzLhnnsnUfdBB8VpL3U03RfNNzebeeqCmmzpUVUVb3bp10Rb64ov+URt8tspK986dI2jd4yti2jvvxPzZt/XrM8G+e3fDfZQ3bco0DQ0eHH9fey3aQnftimmOPDLaRcH9lltqLzN9mzkzpnn8cfcZM+L+pz6Vaa8dOTK+7v/tb56z2ag+J58c85x2Wv7zNOTZZzP/hI88EoF/2mnut96aeR3p151e/83x5pvRNPOtb9Ved2+/nZnutdfc58yJ5d5+e2b8oYfGh++6de5r17r37BkfyHWpqsp8OJ9/foxLt72+8UaETs3gyTWuPtnb444d9bfr1yfXtpzr9sorde+4lIL33qs97vXXa7/Hpayqqvr7mgcFfb4+/DCz55xt1arcHwCFtGtXhEZ6r6Om9AHZffbJvSGnvflmTHfdddHOnA7RJ5+MYNljD/fLLotpP//5zF5zQ1avjudJH2Reu7ZJL7OW8ePd99rLvXt393794rkfeMD9/ffd9903M65Pn8Kc2JP+AMnnAFdVVXzA/tu/xfD27THvz36WmeaHP4xvRq++mvs55s/P1J8+see882JYpICaEvSlc8JUa2rXLvq0P/tstNemb/PmxeOFPmkr2x57RHewupaT7sZ23nnVu2jW1K1bTLt0adzKy6Md+bbbom/+7t2Zdr9x4+J4wKxZ1V9v+pe8st1xB7RvD7/7XaynqVOjvTN9LZEPP4z5Nm7MtH2nx9V1W7s22k7PPTfaml95JY5tDB8e3ebGjIlxvXpFu+u2bXH847XXMnXt3Fn/Mmre0ifv5PNepq98unBhdHNdtqz2+zN2bLzOX/4y9/J++ctoJ77nnjhWM21adN1syW1JJF+N/WQo1K2oe/Tu7pdemtkLzr7ttVemGaWlpPear7mm9mNTp8ZjCxc2/Dxnn+1+8MHRvnzJJbHXnH4dxx+fme7999332y/36/3udzPT1exON2xYZrof/SjGjRuXGZduvz3//NzPXfM2f777ihVxP937xD3aW
yF6n6RPh0/Pc9ddMc3nPpffMrJvt96a/3uydWu0gX/ve+6/+EXMv3lz9WnSTVp13S68MNN+nh43ZUr+NYjkgSbs0Zun98paWUVFhS9q6ZNV6rNlS5xhWPOqd5/+NJx4Yssv/6mn4szMffapPv6992LPMp+LsF17bZxEAnFSy5lnxskkVVVxRD/7pJj03mq2hx6KXiKbN8e3gVmz4qJys2bF3vbGjfDYY3Hyy4oV8MILcaLZ4MFxAtjatfFtol8/GDoUzjqr7lq7d4cRI2Lv+a9/jZ4uXbtmHp83L3rfdO4cy1qwIM4c7d499paPOSZ6rwwa1PB6gThp56tfjR4r+Ro9Ol7v6afHenn11eqPr18fV2TMpV27eH377w8vvRQnOHXoEOP22y//GkQaYGaL3b2iUTPl82kAnAGsBtYAE3M8fjKwBKgEzsnnOYu+R58E6YOI0LSTrP7v/2LeX/86hs86K/boax7k+9//9Y8Ozqa/bcyaVX1cS5zkle4tddppuU9EK7Q//zmW166d+5e+1LLLEmkiWqKN3szaA1OBocBAYLSZDawx2T+BC4B7kdaTbv/t1Cn6gTfWoEFxKefp02Pv9ZFH4pojHWqcXjFkSJyiP3cuHHVU7I0PGxZt6nPnxiWhjzyy2S+nlvQPycydG32/u3Ur/DKyDR4c31iqqtS2LomSz8HY44A17v6yu+8CZgIjsidw9/XuvhzQ1f9bU8+eEbZHHlk7nPNhFgcZFyyIH3aoqsp98k779pnriowdG/N16JA5qNxSJ3tk/5BMS59QApnrwUA0FYkkRD7pcBCwIWt4I5BnQ2l1ZjYeGA/Qt2/fpjyF1PTznzevDXjs2PjBl3ffjb31AQNyTzdhArzzTuzxp116afTuyb4KaaFNmRJn0rbWD8dcdBG8/TZ84QutszyRVtDgwVgzOwc4w93HpYbPBwa5+4Qc0/4W+JO7P9TQgot+MFZEpA1qysHYfJpuNgF9soZ7p8aJiEgbkE/QLwQGmFmZmXUERgGzG5hHRERKRINB7+6VwATgMWAV8IC7rzSzKWY2HMDM/s3MNgLnAr8ys5UtWbSIiOQvr64a7j4HmFNj3OSs+wuJJh0RESkxH89r3YiIfIwo6EVEEk5BLyKScAp6EZGEU9CLiCScgl5EJOEU9CIiCaegFxFJOAW9iEjCKehFRBJOQS8iknAKehGRhFPQi4gknIJeRCThFPQiIgmnoBcRSbiiBb079O4dt6OPhrffLlYlIiLJVrSgN4MzzoCjjoLnnoOV+vFBEZEWUdSmm+nT4aab4v66dcWsREQkuYreRt+vX/xdv76oZYiIJFbRg37vvaFnTwW9iEhLKXrQA/Tvr6YbEZGWUjJBrz16EZGWURJBX1YGr7wCVVXFrkREJHlKIuj794fdu+HVV4tdiYhI8pRE0JeVxV8134iIFF5JBH3//vFXQS8iUnglEfR9+8Zf9bwRESm8kgj6Tp3gwAO1Ry8i0hJKIughmm/WrCl2FSIiyVMyQX/KKfDkk7BqVbErERFJlpIJ+ssug86d4Zpril2JiEiylEzQH3AAXHwxPPAArFhR7GpERJIjr6A3szPMbLWZrTGziTke39PM7k89vsDM+jelmO9/H7p00V69iEghNRj0ZtYemAoMBQYCo81sYI3JxgLb3P1Q4BfA9U0ppnt3uOQSeOghWL68Kc8gIiI1dchjmuOANe7+MoCZzQRGAC9kTTMCuDp1/yHgVjMzd/fGFnT55XDLLXDFFTB5cmPnFhFJnk99Crp1a/r8+QT9QcCGrOGNwKC6pnH3SjPbDnQHtja2oG7d4sDsNdfAnDmNnVtEJHkefhhGjGj6/PkEfcGY2XhgPEDf9OmwOfz4x3DCCVBZ2VqViYiUrvLy5s2fT9BvAvpkDfdOjcs1zUYz6wDsC/yr5hO5+zRgGkBFRUWdzTodO8Lpp+dRmYiINCifXjcLgQFmVmZmHYFRwOwa08wGvpm6fw4wrynt8yIiUngN7
tGn2twnAI8B7YE73H2lmU0BFrn7bGAGcLeZrQHeJD4MRESkBFixdrzN7B1gdVEW3jgH0ISDyq1MNRZOW6izLdQIbaPOtlAjVK+zn7v3aMzMrXowtobV7l5RxOXnxcwWlXqdqrFw2kKdbaFGaBt1toUaofl1lswlEEREpGUo6EVEEq6YQT+tiMtujLZQp2osnLZQZ1uoEdpGnW2hRmhmnUU7GCsiIq1DTTciIgmnoBcRSbiiBH1D17cvBjPrY2ZPmNkLZrbSzC5Jjb/azDaZ2bLUbViR61xvZs+nalmUGre/mf3ZzF5K/W3Gde4KUuOns9bXMjN728wuLYV1aWZ3mNkbZrYia1zO9WfhltR2utzMmnnFkWbVeKOZvZiq4w9mtl9qfH8z25G1Tm8vYo11vr9mdkVqPa42sy+2Ro311Hl/Vo3rzWxZanyx1mVd2VO47dLdW/VGnF27FjgY6Ag8Bwxs7Tpy1NULKE/d7wr8g7j+/tXAD4pdX1ad64EDaoy7AZiYuj8RuL7YddZ4v18D+pXCugROBsqBFQ2tP2AY8ChgwPHAgiLW+AWgQ+r+9Vk19s+ersjrMef7m/o/eg7YEyhL/f+3L1adNR7/OTC5yOuyruwp2HZZjD36j65v7+67gPT17YvK3Te7+5LU/XeAVcTll9uCEcCdqft3Al8uYi01nQasdfdXil0IgLv/jbhMR7a61t8I4C4PzwD7mVmvYtTo7o+7e/p6rs8QFxcsmjrWY11GADPdfae7rwPWEDnQ4uqr08wM+CpwX2vUUpd6sqdg22Uxgj7X9e1LKlAtfgrxGGBBatSE1FekO4rdLAI48LiZLba47DNAT3ffnLr/GtCzOKXlNIrq/0iltC7T6lp/pbqtfpvYo0srM7OlZjbfzE4qVlEpud7fUl2PJwGvu/tLWeOKui5rZE/BtksdjK3BzLoAvwMudfe3gV8ChwBHA5uJr3rFdKK7lxM/7XiRmZ2c/aDHd7uS6DNrcbXT4cCDqVGlti5rKaX1l4uZXQlUAvekRm0G+rr7McDlwL1mtk+Ryiv597eG0VTfCSnqusyRPR9p7nZZjKDP5/r2RWFmexAr+h53/z2Au7/u7h+6exXwa1rpK2dd3H1T6u8bwB9S9bye/uqW+vtG8SqsZiiwxN1fh9Jbl1nqWn8lta2a2QXAmcCY1D8+qeaQf6XuLybavz9VjPrqeX9Laj0CWPxuxtnA/elxxVyXubKHAm6XxQj6fK5v3+pS7XUzgFXufnPW+Oy2r68AK2rO21rMrLOZdU3fJw7QraD67wF8E5hVnAprqbbHVErrsoa61t9s4BupXg7HA9uzvkq3KjM7A/ghMNzd388a38PM2qfuHwwMAF4uUo11vb+zgVFmtqeZlRE1Ptva9dUwBHjR3TemRxRrXdaVPRRyu2ztI8xZR43/QXxiXlmMGnLUdCLx1Wg5sCx1GwbcDTyfGj8b6FXEGg8mei88B6xMrzvi93nnAi8BfwH2L4H12Zn4lbF9s8YVfV0SHzybgd1E2+bYutYf0athamo7fR6oKGKNa4h22fS2eXtq2pGpbWEZsAQ4q4g11vn+Alem1uNqYGgx3+/U+N8C360xbbHWZV3ZU7DtUpdAEBFJOB2MFRFJOAW9iEjCKehFRBJOQS8iknAKehGRhFPQi4gknIJeRCTh/h8cAkKdgs7rsAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "## Weak learner: decision stumps\n", "\n", "T=200\n", "bdt = AdaBoostClassifier(DecisionTreeClassifier(max_depth=1), n_estimators=T,\n", " algorithm='SAMME.R')\n", "\n", "# fit the model (as usual)\n", "bdt.fit(Xtr, Ytr[:,0])\n", "\n", "# The result of AdaBoost with decision stumps can be analyzed\n", "# to find the most important variables from the data set:\n", "\n", "np.where(bdt.feature_importances_ > 0.01)\n", "\n", "# returns the indices of the variables whose importance score\n", "# exceeds a threshold (0.01)\n", "\n", "# Compute the errors after each boosting step:\n", "# - train error (zero_one_loss expects y_true first, then y_pred)\n", "err_tr = np.zeros((T,))\n", "for i, yp in enumerate(bdt.staged_predict(Xtr)):\n", "    err_tr[i] = zero_one_loss(Ytr[:,0], yp)\n", "\n", "# - test error\n", "err_ts = np.zeros((T,))\n", "for i, yp in enumerate(bdt.staged_predict(Xts)):\n", "    err_ts[i] = zero_one_loss(Yts[:,0], yp)\n", "\n", "fig = plt.figure()\n", "ax = plt.subplot(111)\n", "ax.set_ylim(-0.01, 0.5)\n", "ax.set_xlim(0, T+2)\n", "ax.plot(np.arange(T)+1, err_tr, color='blue')\n", "ax.plot(np.arange(T)+1, err_ts, color='red')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Random Forest" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=10,\n", " n_jobs=None, oob_score=False, random_state=None,\n", " verbose=0, warm_start=False)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "clf = RandomForestClassifier(n_estimators=10)\n", "clf.fit(Xtr, Ytr[:,0])" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.0" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "zero_one_loss(clf.predict(Xtr), Ytr[:,0]) # train error" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.10999999999999999" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "zero_one_loss(clf.predict(Xts), Yts[:,0]) # test error" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=2, max_features='auto', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=10,\n", " n_jobs=None, oob_score=False, random_state=None,\n", " verbose=0, warm_start=False)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Other parameters\n", "clf = RandomForestClassifier(n_estimators=10, max_depth=2)\n", "clf.fit(Xtr, Ytr[:,0])" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train error 0.02308\tTest error: 0.14000\n" ] } ], "source": [ "print(\"Train error {:1.5f}\\tTest error: {:1.5f}\".format(zero_one_loss(clf.predict(Xtr), Ytr[:,0]), zero_one_loss(clf.predict(Xts), Yts[:,0])))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### TO DO:\n", "* What can you say about error rate on the test set in the 2nd example (with respect to 1st example)?\n", "* Try other parameter combinations..." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 1 }