{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"___\n",
"\n",
"<a href='http://www.pieriandata.com'><img src='../Pierian_Data_Logo.png'/></a>\n",
"___\n",
"<center><em>Copyright by Pierian Data Inc.</em></center>\n",
"<center><em>For more information, visit us at <a href='http://www.pieriandata.com'>www.pieriandata.com</a></em></center>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# KNN - K Nearest Neighbors - Classification\n",
"\n",
"To understand KNN for classification, we'll work with a simple dataset representing gene expression levels. Gene expression levels are calculated as the ratio between the expression of the target gene (i.e., the gene of interest) and the expression of one or more reference genes (often housekeeping genes). This dataset is synthetic and specifically designed to show some of the strengths and limitations of using KNN for classification.\n",
"\n",
"\n",
"More info on gene expression: https://www.sciencedirect.com/topics/biochemistry-genetics-and-molecular-biology/gene-expression-level\n",
"\n",
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('../DATA/gene_expression.csv')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Gene One</th>\n",
" <th>Gene Two</th>\n",
" <th>Cancer Present</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>4.3</td>\n",
" <td>3.9</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2.5</td>\n",
" <td>6.3</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5.7</td>\n",
" <td>3.9</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>6.1</td>\n",
" <td>6.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>7.4</td>\n",
" <td>3.4</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Gene One Gene Two Cancer Present\n",
"0 4.3 3.9 1\n",
"1 2.5 6.3 0\n",
"2 5.7 3.9 1\n",
"3 6.1 6.2 0\n",
"4 7.4 3.4 1"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:xlabel='Gene One', ylabel='Gene Two'>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"sns.scatterplot(x='Gene One',y='Gene Two',hue='Cancer Present',data=df,alpha=0.7)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x2657fb62ac8>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"sns.scatterplot(x='Gene One',y='Gene Two',hue='Cancer Present',data=df)\n",
"plt.xlim(2,6)\n",
"plt.ylim(3,10)\n",
"plt.legend(loc=(1.1,0.5))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train|Test Split and Scaling Data"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import StandardScaler"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"X = df.drop('Cancer Present',axis=1)\n",
"y = df['Cancer Present']"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"scaler = StandardScaler()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"scaled_X_train = scaler.fit_transform(X_train)\n",
"scaled_X_test = scaler.transform(X_test)"
]
},
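{
"cell_type": "markdown",
"metadata": {},
"source": [
"KNN classifies by distance, so a feature with a larger numeric range would dominate the neighbor search. `StandardScaler` rescales each column to zero mean and unit variance using statistics learned from the data passed to `.fit()` (here, the training data only). The sketch below uses a small hypothetical array, not the gene data, and checks the scaler against the manual formula z = (x - mean) / std, where scikit-learn uses the population standard deviation (ddof=0)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Hypothetical toy data: the second column has a much larger range than the first\n",
"X_demo = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 60.0]])\n",
"scaler_demo = StandardScaler().fit(X_demo)\n",
"\n",
"# Manual standardization with the column means and population std (ddof=0)\n",
"manual = (X_demo - X_demo.mean(axis=0)) / X_demo.std(axis=0)\n",
"print(np.allclose(scaler_demo.transform(X_demo), manual))"
]
},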
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"knn_model = KNeighborsClassifier(n_neighbors=1)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(n_neighbors=1)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"knn_model.fit(scaled_X_train,y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Understanding KNN and Choosing K Value"
]
},
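{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before looking at the test data, here is a minimal NumPy sketch of the idea behind KNN classification. It is an illustration only, not scikit-learn's implementation (which can also use tree-based neighbor searches): measure the Euclidean distance from a query point to every training point, keep the k closest, and return the majority label. The helper `knn_predict_one` and the toy arrays are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def knn_predict_one(X_fit, y_fit, x_query, k=3):\n",
"    # Euclidean distance from the query point to every training point\n",
"    distances = np.sqrt(((X_fit - x_query) ** 2).sum(axis=1))\n",
"    # Indices of the k closest training points\n",
"    nearest = np.argsort(distances)[:k]\n",
"    # Majority vote among the neighbors' labels (ties go to the smallest label)\n",
"    labels, counts = np.unique(y_fit[nearest], return_counts=True)\n",
"    return labels[np.argmax(counts)]\n",
"\n",
"# Hypothetical toy data: two features, two well-separated classes\n",
"X_toy = np.array([[1.0, 1.0], [1.2, 0.8], [5.0, 5.0], [5.2, 4.9]])\n",
"y_toy = np.array([0, 0, 1, 1])\n",
"print(knn_predict_one(X_toy, y_toy, np.array([1.1, 1.0]), k=3))"
]
},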
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"full_test = pd.concat([X_test,y_test],axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"900"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(full_test)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:xlabel='Gene One', ylabel='Gene Two'>"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"sns.scatterplot(x='Gene One',y='Gene Two',hue='Cancer Present',\n",
" data=full_test,alpha=0.7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"y_pred = knn_model.predict(scaled_X_test)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import classification_report,confusion_matrix,accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8922222222222222"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_score(y_test,y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[420, 50],\n",
" [ 47, 383]], dtype=int64)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"confusion_matrix(y_test,y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.90 0.89 0.90 470\n",
" 1 0.88 0.89 0.89 430\n",
"\n",
" accuracy 0.89 900\n",
" macro avg 0.89 0.89 0.89 900\n",
"weighted avg 0.89 0.89 0.89 900\n",
"\n"
]
}
],
"source": [
"print(classification_report(y_test,y_pred))"
]
},
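{
"cell_type": "markdown",
"metadata": {},
"source": [
"The numbers in the report can be recomputed by hand from the confusion matrix printed above (in scikit-learn's convention, rows are true classes and columns are predicted classes). The sketch below copies that matrix as a literal array: accuracy is the diagonal over the total, precision divides each diagonal entry by its column sum, and recall divides it by its row sum."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Confusion matrix from the cell above: rows = true class, columns = predicted class\n",
"cm = np.array([[420, 50],\n",
"               [47, 383]])\n",
"\n",
"accuracy = np.trace(cm) / cm.sum()        # correct predictions / all predictions\n",
"precision = np.diag(cm) / cm.sum(axis=0)  # per class: correct / all predicted as that class\n",
"recall = np.diag(cm) / cm.sum(axis=1)     # per class: correct / all truly in that class\n",
"print(round(accuracy, 4), precision.round(2), recall.round(2))"
]
},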
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Elbow Method for Choosing Reasonable K Values\n",
"\n",
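"**NOTE: This uses the test set to select the hyperparameter K, which leaks information from the test set into model selection. The cross-validated grid search below avoids this by tuning K on the training data only.**"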
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"test_error_rates = []\n",
"\n",
"\n",
"for k in range(1,30):\n",
" knn_model = KNeighborsClassifier(n_neighbors=k)\n",
" knn_model.fit(scaled_X_train,y_train) \n",
" \n",
" y_pred_test = knn_model.predict(scaled_X_test)\n",
" \n",
" test_error = 1 - accuracy_score(y_test,y_pred_test)\n",
" test_error_rates.append(test_error)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'K Value')"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 2000x1200 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10,6),dpi=200)\n",
"plt.plot(range(1,30),test_error_rates,label='Test Error')\n",
"plt.legend()\n",
"plt.ylabel('Error Rate')\n",
"plt.xlabel(\"K Value\")"
]
},
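{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to keep the test set out of the choice of K is to split a validation set off the training data and run the same elbow loop against it. This is sketched here on hypothetical synthetic data from `make_classification` (standing in for the training data); the variable names are illustrative. The cross-validated grid search that follows is a more thorough version of the same idea."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Hypothetical synthetic data standing in for X_train / y_train\n",
"X_demo, y_demo = make_classification(n_samples=600, n_features=2, n_redundant=0, random_state=0)\n",
"X_tr, X_val, y_tr, y_val = train_test_split(X_demo, y_demo, test_size=0.25, random_state=0)\n",
"\n",
"# Fit the scaler on the training portion only, then apply it to both parts\n",
"scaler_demo = StandardScaler().fit(X_tr)\n",
"X_tr_s, X_val_s = scaler_demo.transform(X_tr), scaler_demo.transform(X_val)\n",
"\n",
"val_error_rates = []\n",
"for k in range(1, 30):\n",
"    model = KNeighborsClassifier(n_neighbors=k).fit(X_tr_s, y_tr)\n",
"    val_error_rates.append(1 - accuracy_score(y_val, model.predict(X_val_s)))\n",
"\n",
"# K with the lowest validation error (the first one, if several tie)\n",
"best_k = 1 + val_error_rates.index(min(val_error_rates))\n",
"print(best_k)"
]
},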
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Full Cross Validation Grid Search for K Value"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating a Pipeline to find K value\n",
"\n",
"**Follow along very carefully here! The string name we give each pipeline step (for example `knn`) is reused as a prefix in the parameter grid, so the names must match exactly. You can choose different names, but only if you change them consistently everywhere they are used!**\n",
"\n",
"We'll use a Pipeline object to set up a workflow of operations:\n",
"\n",
"1. Scale Data\n",
"2. Create Model on Scaled Data\n",
"\n",
"----\n",
"**How does the scaler work inside a Pipeline with CV? Is scikit-learn \"smart\" enough to understand .fit() on train vs. .transform() on train and test?**\n",
"\n",
"**Yes! Scikit-learn's Pipeline is well suited for this! [Full info in the documentation](https://scikit-learn.org/stable/modules/preprocessing.html#standardization-or-mean-removal-and-variance-scaling)**\n",
"\n",
"When you use the StandardScaler as a step inside a Pipeline, scikit-learn handles this internally for you.\n",
"\n",
"What happens for each CV split can be described as follows:\n",
"\n",
"* Step 0: The data are split into TRAINING folds and a TEST (validation) fold according to the cv parameter that you specified in GridSearchCV.\n",
"* Step 1: The scaler is fitted on the TRAINING data.\n",
"* Step 2: The scaler transforms the TRAINING data.\n",
"* Step 3: The model is fitted/trained using the transformed TRAINING data.\n",
"* Step 4: The scaler (still fitted on the TRAINING data only) transforms the TEST data.\n",
"* Step 5: The trained model predicts using the transformed TEST data, and the score is recorded.\n",
"\n",
"----"
]
},
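{
"cell_type": "markdown",
"metadata": {},
"source": [
"The steps above can be sketched as an explicit loop over `KFold` splits, assuming hypothetical synthetic data from `make_classification` and a fixed K=5. This illustrates the idea and is not scikit-learn's internal code; the last line checks that the manual fold scores agree with `cross_val_score` run on an equivalent Pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.model_selection import KFold, cross_val_score\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"X_demo, y_demo = make_classification(n_samples=300, n_features=2, n_redundant=0, random_state=0)\n",
"cv = KFold(n_splits=5)  # Step 0: how the data will be split into folds\n",
"\n",
"fold_scores = []\n",
"for train_idx, val_idx in cv.split(X_demo):\n",
"    fold_scaler = StandardScaler().fit(X_demo[train_idx])    # Step 1: fit on training folds only\n",
"    X_tr = fold_scaler.transform(X_demo[train_idx])          # Step 2: transform training folds\n",
"    model = KNeighborsClassifier(n_neighbors=5).fit(X_tr, y_demo[train_idx])  # Step 3: train\n",
"    X_val = fold_scaler.transform(X_demo[val_idx])           # Step 4: transform held-out fold\n",
"    fold_scores.append(model.score(X_val, y_demo[val_idx]))  # Step 5: predict and score\n",
"\n",
"pipe_demo = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier(n_neighbors=5))])\n",
"pipe_scores = cross_val_score(pipe_demo, X_demo, y_demo, cv=cv)\n",
"print(np.allclose(fold_scores, pipe_scores))"
]
},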
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"scaler = StandardScaler()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"knn = KNeighborsClassifier()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['algorithm', 'leaf_size', 'metric', 'metric_params', 'n_jobs', 'n_neighbors', 'p', 'weights'])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"knn.get_params().keys()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# Recommended: give each step a string name that matches its variable name\n",
"operations = [('scaler',scaler),('knn',knn)]"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"pipe = Pipeline(operations)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"**Note:** If your parameter grid is going inside a Pipeline, each parameter name needs to be specified in the following manner:\n",
"\n",
"* chosen_string_name + **two** underscores + parameter key name\n",
"* step_name + __ + parameter name\n",
"* knn + __ + n_neighbors\n",
"* knn__n_neighbors\n",
"\n",
"[StackOverflow on this](https://stackoverflow.com/questions/41899132/invalid-parameter-for-sklearn-estimator-pipeline)\n",
"\n",
"We have to do this because the prefix tells scikit-learn which step of the pipeline each parameter belongs to (otherwise it could not tell whether n_neighbors is a parameter of the scaler or of the KNN model).\n",
"\n",
"---"
]
},
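{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to see this naming scheme is to ask a pipeline for its parameters. The sketch below builds a small hypothetical pipeline (`pipe_demo`, separate from the `pipe` used in this notebook) and shows that nested parameters appear as `<step name>__<parameter>`, and that `set_params` accepts the same names a `param_grid` would use."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"pipe_demo = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier())])\n",
"params = pipe_demo.get_params()\n",
"\n",
"# Nested parameters are exposed as <step name>__<parameter name>\n",
"print('knn__n_neighbors' in params, 'scaler__with_mean' in params)\n",
"\n",
"# set_params accepts the same names a GridSearchCV param_grid uses\n",
"pipe_demo.set_params(knn__n_neighbors=7)\n",
"print(pipe_demo.named_steps['knn'].n_neighbors)"
]
},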
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"k_values = list(range(1,20))"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"k_values"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"param_grid = {'knn__n_neighbors': k_values}"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"full_cv_classifier = GridSearchCV(pipe,param_grid,cv=5,scoring='accuracy')"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=5,\n",
" estimator=Pipeline(steps=[('scaler', StandardScaler()),\n",
" ('knn', KNeighborsClassifier())]),\n",
" param_grid={'knn__n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,\n",
" 12, 13, 14, 15, 16, 17, 18, 19]},\n",
" scoring='accuracy')"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Use full X and y if you DON'T want a hold-out test set\n",
"# Use X_train and y_train if you DO want a holdout test set (X_test,y_test)\n",
"full_cv_classifier.fit(X_train,y_train)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'memory': None,\n",
" 'steps': [('scaler', StandardScaler()),\n",
" ('knn', KNeighborsClassifier(n_neighbors=14))],\n",
" 'verbose': False,\n",
" 'scaler': StandardScaler(),\n",
" 'knn': KNeighborsClassifier(n_neighbors=14),\n",
" 'scaler__copy': True,\n",
" 'scaler__with_mean': True,\n",
" 'scaler__with_std': True,\n",
" 'knn__algorithm': 'auto',\n",
" 'knn__leaf_size': 30,\n",
" 'knn__metric': 'minkowski',\n",
" 'knn__metric_params': None,\n",
" 'knn__n_jobs': None,\n",
" 'knn__n_neighbors': 14,\n",
" 'knn__p': 2,\n",
" 'knn__weights': 'uniform'}"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"full_cv_classifier.best_estimator_.get_params()"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['mean_fit_time', 'std_fit_time', 'mean_score_time', 'std_score_time', 'param_knn__n_neighbors', 'params', 'split0_test_score', 'split1_test_score', 'split2_test_score', 'split3_test_score', 'split4_test_score', 'mean_test_score', 'std_test_score', 'rank_test_score'])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"full_cv_classifier.cv_results_.keys()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
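"Let's check our understanding:\n",
"**How many total runs did we do?** The grid contains 19 candidate K values (checked below). With `cv=5`, each candidate is trained and scored on 5 different folds, so the search performs 19 * 5 = 95 fits, plus one final refit of the best pipeline on all of `X_train` (`refit=True` is the default)."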
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"19"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(k_values)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.90238095, 0.90285714, 0.91857143, 0.91333333, 0.92380952,\n",
" 0.92142857, 0.9252381 , 0.9247619 , 0.9252381 , 0.92190476,\n",
" 0.9252381 , 0.9247619 , 0.92761905, 0.92904762, 0.92809524,\n",
" 0.92809524, 0.92904762, 0.92857143, 0.92761905])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"full_cv_classifier.cv_results_['mean_test_score']"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"19"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(full_cv_classifier.cv_results_['mean_test_score'])"
]
},
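{
"cell_type": "markdown",
"metadata": {},
"source": [
"These counts can also be read off a fitted search object instead of being worked out by hand. This sketch assumes hypothetical synthetic data and a hypothetical pipeline `pipe_demo`; `cv_results_['params']` holds one entry per candidate, and `n_splits_` is the number of CV folds the search actually used."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"X_demo, y_demo = make_classification(n_samples=300, n_features=2, n_redundant=0, random_state=0)\n",
"pipe_demo = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier())])\n",
"search = GridSearchCV(pipe_demo, {'knn__n_neighbors': list(range(1, 20))}, cv=5, scoring='accuracy')\n",
"search.fit(X_demo, y_demo)\n",
"\n",
"n_candidates = len(search.cv_results_['params'])  # one entry per K value\n",
"total_fits = n_candidates * search.n_splits_       # each candidate is fit once per fold\n",
"print(n_candidates, total_fits)"
]
},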
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Final Model\n",
"\n",
"We just saw that our grid search recommends K=14 (in line with our alternative Elbow Method). Let's now use the Pipeline again, but this time there is no need for a grid search; instead, we will evaluate on our hold-out test set."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"scaler = StandardScaler()\n",
"knn14 = KNeighborsClassifier(n_neighbors=14)\n",
"operations = [('scaler',scaler),('knn14',knn14)]"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"pipe = Pipeline(operations)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('scaler', StandardScaler()),\n",
" ('knn14', KNeighborsClassifier(n_neighbors=14))])"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipe.fit(X_train,y_train)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"pipe_pred = pipe.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.93 0.95 0.94 470\n",
" 1 0.95 0.92 0.93 430\n",
"\n",
" accuracy 0.94 900\n",
" macro avg 0.94 0.94 0.94 900\n",
"weighted avg 0.94 0.94 0.94 900\n",
"\n"
]
}
],
"source": [
"print(classification_report(y_test,pipe_pred))"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"single_sample = X_test.iloc[40]"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Gene One 3.8\n",
"Gene Two 6.3\n",
"Name: 194, dtype: float64"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"single_sample"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0], dtype=int64)"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipe.predict(single_sample.values.reshape(1, -1))"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.92857143, 0.07142857]])"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipe.predict_proba(single_sample.values.reshape(1, -1))"
]
},
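{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the default `weights='uniform'`, `predict_proba` is the share of the K neighbors in each class, so 0.92857143 in the output above corresponds to 13 of the 14 neighbors belonging to class 0 (13/14 = 0.928571...). The sketch below checks that relationship on hypothetical synthetic data: it scales a query point with the fitted scaler step, looks up its neighbors with `kneighbors`, and counts their labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"X_demo, y_demo = make_classification(n_samples=300, n_features=2, n_redundant=0, random_state=0)\n",
"pipe_demo = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier(n_neighbors=14))])\n",
"pipe_demo.fit(X_demo, y_demo)\n",
"\n",
"query = X_demo[:1]  # one sample, kept 2-D\n",
"# Scale the query with the fitted scaler step, then find its 14 nearest training points\n",
"scaled_query = pipe_demo.named_steps['scaler'].transform(query)\n",
"_, neighbor_idx = pipe_demo.named_steps['knn'].kneighbors(scaled_query)\n",
"neighbor_labels = y_demo[neighbor_idx[0]]\n",
"\n",
"# Uniform weights: each class probability is that class's share of the neighbors\n",
"manual_proba = np.array([np.mean(neighbor_labels == c) for c in pipe_demo.classes_])\n",
"print(manual_proba, pipe_demo.predict_proba(query)[0])"
]
},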
{
"cell_type": "markdown",
"metadata": {},
"source": [
"----"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 1
}