2022-12-29 17:27:20 +01:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 1,
|
|
|
|
"metadata": {
|
|
|
|
"collapsed": false
|
|
|
|
},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"%matplotlib inline"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"\n",
|
|
|
|
"# Recognizing hand-written digits\n",
|
|
|
|
"\n",
|
|
|
|
"This example shows how scikit-learn can be used to recognize images of\n",
|
|
|
|
"hand-written digits, from 0-9.\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 2,
|
|
|
|
"metadata": {
|
|
|
|
"collapsed": false
|
|
|
|
},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>\n",
|
|
|
|
"# License: BSD 3 clause\n",
|
|
|
|
"\n",
|
|
|
|
"# Standard scientific Python imports\n",
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
"\n",
|
|
|
|
"# Import datasets, classifiers and performance metrics\n",
|
|
|
|
"from sklearn import datasets, svm, metrics\n",
|
|
|
|
"from sklearn.model_selection import train_test_split"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Digits dataset\n",
|
|
|
|
"\n",
|
|
|
|
"The digits dataset consists of 8x8\n",
|
|
|
|
"pixel images of digits. The ``images`` attribute of the dataset stores\n",
|
|
|
|
"8x8 arrays of grayscale values for each image. We will use these arrays to\n",
|
|
|
|
"visualize the first 4 images. The ``target`` attribute of the dataset stores\n",
|
|
|
|
"the digit each image represents and this is included in the title of the 4\n",
|
|
|
|
"plots below.\n",
|
|
|
|
"\n",
|
|
|
|
"Note: if we were working from image files (e.g., 'png' files), we would load\n",
|
|
|
|
"them using :func:`matplotlib.pyplot.imread`.\n",
|
|
|
|
"\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 3,
|
|
|
|
"metadata": {
|
|
|
|
"collapsed": false
|
|
|
|
},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxsAAADSCAYAAAAi0d0oAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAA9hAAAPYQGoP6dpAAASFklEQVR4nO3db5CVZd0H8N8KsRsBsiLkkiUsOmPJIAHNJCbgsBCkBkmgLxhZxgYqGaM/M8sU5oJlkjZjhRnxBgNzlDLIJlMY2JymN7GyloYzSyw6GU6Kyx9F/no/L57HfaIld8Hr8rC7n88MM+x1zv29rwP82POd++w5ZUVRFAEAAJDYOaXeAAAA0D0pGwAAQBbKBgAAkIWyAQAAZKFsAAAAWSgbAABAFsoGAACQhbIBAABkoWwAAABZKBtnoLa2NoYNG3ZGx9bX10dZWVnaDcFZyJxAx8wJdMycdG3dqmyUlZV16ldDQ0Opt3rW+dOf/hSf+tSnom/fvnHBBRfErbfeGq+//nqpt0UG5uTMPPnkk3HzzTfHyJEjo1evXmf8jY+uwZycvkOHDsV9990XU6dOjaqqqujfv398/OMfj/vvvz9OnDhR6u2RgTk5M3feeWd88pOfjMGDB0dFRUVccsklsXjx4njllVdKvbUsyoqiKEq9iVTWrVt30tc///nPY9OmTbF27dqT1qdMmRIf/OAHz/g8x44di7feeivKy8tP+9jjx4/H8ePHo6Ki4ozPn1pTU1NcccUV8dGPfjQWLFgQ//jHP+Kee+6Jq6++Oh5//PFSb4/EzMmZqa2tjYcffjjGjBkTL774YvTq1St2795d6m2RiTk5fc8++2yMGjUqJk+eHFOnTo0BAwbEE088Eb/+9a/jpptuigceeKDUWyQxc3JmZs2aFYMHD45LL700+vfvHzt27IjVq1fHkCFDoqmpKT7wgQ+UeotpFd3YLbfcUnTmIb7xxhvvwW7OXtOnTy+qqqqK/fv3t62tXr26iIjiiSeeKOHOeC+Yk8556aWXiqNHjxZFURTXXHNNcdFFF5V2Q7ynzEnHXnnlleLZZ59ttz5//vwiIorm5uYS7Ir3kjk5c7/85S+LiCgeeuihUm8luW71MqrOmDRpUowcOTIaGxtjwoQJ0bdv3/jmN78ZEREbN26Ma665JoYOHRrl5eUxYsSIuOOOO9pd/v3P1w7u3r07ysrK4p577omf/exnMWLEiCgvL49PfOIT8ec///mkY0/12sGysrJYtGhRbNiwIUaOHBnl5eVx2WWXxe9///t2+29oaIhx48ZFRUVFjBgxIlatWnXKzFdffTWef/75OHTo0Dv+eRw4cCA2bdoUc+fOjQEDBrSt33TTTdGvX7945JFH3vF4uidz0t7QoUPjfe97X4f3o+cwJyc7//zz47LLLmu3/rnPfS4iInbs2PGOx9M9mZPOefvx7du374yOP5v1LvUGSmHv3r0xffr0uPHGG2Pu3Lltl/bWrFkT/fr1i6997WvRr1+/2LJlS3z729+OAwcOxN13391h7i9+8Ys4ePBgLFy4MMrKyuL73/9+XH/99bFr164On6T88Y9/jEcffTS+/OUvR//+/eNHP/pRzJo1K1588cUYNGhQRERs3749pk2bFlVVVbFs2bI4ceJELF++PAYPHtwub+XKlbFs2bLYunVrTJo06b+e969//WscP348xo0bd9J6nz59YvTo0bF9+/YOHzfdkzmBjpmTjr388ssR8b9lhJ7JnLRXFEXs3bs3jh8/Hs3NzbFkyZLo1atX9/xeVOpLKzmd6nLexIkTi4gofvrTn7a7/6FDh9qtLVy4sOjbt29x+PDhtrV58+ad9BKKlpaWIiKKQYMGFa+99lrb+saNG4uIKB577LG2tdtvv73dniKi6NOnT7Fz5862tWeeeaaIiOLHP/5x29p1111X9O3bt3jppZfa1pqbm4vevXu3y3z7PFu3bm33mP7d+vXri4gonnrqqXa3zZ49u7jgggve8Xi6PnPS8Zz8Jy+j6nnMyenPSVEUxZEjR4qPfexjxfDhw4tjx46d9vF0Leak83OyZ8+eIiLafl144YXFww8/3Klju5oe9zKqiIjy8vKYP39+u/X3v//9bb8/ePBgvPrqq3HVVVfFoUOH4vnnn+8w94YbbojKysq2r6+66qqIiNi1a1eHx9bU1MSIESPavh41alQMGDCg7dgTJ07E5s2bY+bMmTF06NC2+1188cUxffr0dnn19fVRFEWHDfnNN9+MiDjlD11VVFS03U7PY06gY+bknS1atCj+9re/xcqVK6N37x75YgrCnJzKeeedF5s2bYrHHnssli9fHueff363fRfQHjn5H/rQh6JPnz7t1p977rlYunRpbNmyJQ4cOHDSbfv37+8w9yMf+chJX789AK2trad97NvHv33sv/71r3jzzTfj4osvbne/U6111tuDfuTIkXa3HT58+KT/COhZzAl0zJz8d3fffXesXr067rjjjvjMZz6TLJeux5y016dPn6ipqYmIiGuvvTYmT54cV155ZQwZMiSuvfbad51/NumRZeNUT6D37dsXEydOjAEDBsTy5ctjxIgRUVFREU8//XTU1dXFW2+91WFur169TrledOLdhd/Nse9GVVVVRETs2bOn3W179uw5qc3Ts5gT6Jg5ObU1a9ZEXV1dfPGLX4ylS5e+Z+fl7GROOjZ+/PioqqqKBx98UNnorhoaGmLv3r3x6KOPxoQJE9rWW1paSrir/zdkyJCoqKiInTt3trvtVGudNXLkyOjdu3ds27Yt5syZ07Z+9OjRaGpqOmkNeuqcwOno6XOycePG+MIXvhDXX3993Hfffe86j+6pp8/JqRw+fLhTV3S6mh75Mxun8nbD/fdGe/To0fjJT35Sqi2dpFevXlFTUxMbNmyIf/7zn23rO3fuPOUH73X2LdjOPffcqKmpiXXr1sXBgwfb1teuXRuvv/56zJ49O92DoMvrqXMCp6Mnz8lTTz0VN954Y0yYMCEefPDBOOccTzM4tZ46J2+88cYp7/OrX/0qWltb2707aHfgysb/GT9+fFRWVsa8efPi1ltvjbKysli7du1Z9fKM+vr6ePLJJ+PKK6+ML33pS3HixIlYuXJljBw5Mpqamk667+m8Bdt3v/vdGD9+fEycOLHtE8R/8IMfxNSpU2PatGn5HhBdTk+ek7/85S/xm9/8JiL+95vN/v374zvf+U5ERFx++eVx3XXX5Xg4dEE9dU5eeOGF+OxnPxtlZWXx+c9/PtavX3/S7aNGjYpRo0ZleDR0RT11Tpqbm6OmpiZuuOGGuPTSS+Occ86Jbdu2xbp162LYsGHxla98Je+DKgFl4/8MGjQofvvb38bXv/71WLp0aVRWVsbcuXNj8uTJ8elPf7rU24uIiLFjx8bjjz8e3/jGN+K2226LD3/4w7F8+fLYsWNHp9614b8ZM2ZMbN68Oerq6uKrX/1q9O/fP26++eb43ve+l3D3dAc9eU6efvrpuO22205ae/vrefPmKRu06alz0tLS0vYSkFtuuaXd7bfffruyQZueOicXXnhhzJo1K7Zs2RIPPPBAHDt2LC666KJYtGhRfOtb32r7jI/upKw4myokZ2TmzJnx3HPPRXNzc6m3AmctcwIdMyfQMXNyeryYsov5z8+9aG5ujt/97nc+JwD+jTmBjpkT6Jg5efdc2ehiqqqqora2Nqqrq+OFF16I+++/P44cORLbt2+PSy65pNTbg7OCOYGOmRPomDl59/zMRhczbdq0eOihh+Lll1+O8vLyuOKKK+LOO+/0Dx7+jTmBjpkT6Jg5efdc2QAAALLwMxsAAEAWygYAAJCFsgEAAGTR7X5A/D8/sTSFurq65JlTpkxJnhkRcddddyXPrKysTJ5J95PjbQD37duXPDMiYtmyZckzZ8yYkTyT7qehoSF55syZM5NnRkSMHj06eWaOx0/prVixInnmkiVLkmcOHz48eWZERGNjY/LM7vTcy5UNAAAgC2UDAADIQtkAAACyUDYAAIAslA0AACALZQMAAMhC2QAAALJQNgAAgCyUDQAAIAtlAwAAy
|
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 1000x300 with 4 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"digits = datasets.load_digits()\n",
|
|
|
|
"\n",
|
|
|
|
"_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))\n",
|
|
|
|
"for ax, image, label in zip(axes, digits.images, digits.target):\n",
|
|
|
|
" ax.set_axis_off()\n",
|
|
|
|
" ax.imshow(image, cmap=plt.cm.gray_r, interpolation=\"nearest\")\n",
|
|
|
|
" ax.set_title(\"Training: %i\" % label)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Classification\n",
|
|
|
|
"\n",
|
|
|
|
"To apply a classifier on this data, we need to flatten the images, turning\n",
|
|
|
|
"each 2-D array of grayscale values from shape ``(8, 8)`` into shape\n",
|
|
|
|
"``(64,)``. Subsequently, the entire dataset will be of shape\n",
|
|
|
|
"``(n_samples, n_features)``, where ``n_samples`` is the number of images and\n",
|
|
|
|
"``n_features`` is the total number of pixels in each image.\n",
|
|
|
|
"\n",
|
|
|
|
"We can then split the data into train and test subsets and fit a support\n",
|
|
|
|
"vector classifier on the train samples. The fitted classifier can\n",
|
|
|
|
"subsequently be used to predict the value of the digit for the samples\n",
|
|
|
|
"in the test subset.\n",
|
|
|
|
"\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 4,
|
|
|
|
"metadata": {
|
|
|
|
"collapsed": false
|
|
|
|
},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# flatten the images\n",
|
|
|
|
"n_samples = len(digits.images)\n",
|
|
|
|
"data = digits.images.reshape((n_samples, -1))\n",
|
|
|
|
"\n",
|
|
|
|
"# Create a classifier: a support vector classifier\n",
|
|
|
|
"clf = svm.SVC(gamma=0.001)\n",
|
|
|
|
"\n",
|
|
|
|
"# Split data into 50% train and 50% test subsets\n",
|
|
|
|
"X_train, X_test, y_train, y_test = train_test_split(\n",
|
|
|
|
" data, digits.target, test_size=0.5, shuffle=False\n",
|
|
|
|
")\n",
|
|
|
|
"\n",
|
|
|
|
"# Learn the digits on the train subset\n",
|
|
|
|
"clf.fit(X_train, y_train)\n",
|
|
|
|
"\n",
|
|
|
|
"# Predict the value of the digit on the test subset\n",
|
|
|
|
"predicted = clf.predict(X_test)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"Below we visualize the first 4 test samples and show their predicted\n",
|
|
|
|
"digit value in the title.\n",
|
|
|
|
"\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 5,
|
|
|
|
"metadata": {
|
|
|
|
"collapsed": false
|
|
|
|
},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxsAAADSCAYAAAAi0d0oAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUXUlEQVR4nO3dfWxV9f0H8E+hqJQiZVgV0R9lDohz04qbLhlqHagRNXTGTd1DaM0y9iDSmc2HLZFOjU5HTE1lRrOsdYtG5wwlbuqmGTTiHqMtmToW5lqU4IQltC4OlNLz+4O0roK0Zefb2+LrlfQPzr33fb693E/vfd9ze1qUZVkWAAAAORtX6AUAAACHJmUDAABIQtkAAACSUDYAAIAklA0AACAJZQMAAEhC2QAAAJJQNgAAgCSUDQAAIAllY5gqKiqipqam/9/r1q2LoqKiWLduXW77KCoqivr6+tzyYKSZExgaswKDMydj25gqG83NzVFUVNT/dcQRR8ScOXPi6quvjjfeeKPQyxuWJ554Ykw9qH/+85/Hpz71qSgrK4tp06bFOeecE7/61a8KvSz2w5wUjjkZW8xK4e3evTs++tGPRlFRUaxcubLQy2E/zEnh3HPPPXHSSSfF4YcfHjNmzIhrr7023nrrrUIva9iKC72Ag3HzzTfHrFmzYteuXbF+/fq4995744knnogXX3wxSkpKRnQtZ599duzcuTMOO+ywYd3uiSeeiFWrVu33Qb9z584oLh49/zWNjY1xzTXXxEUXXRQ/+MEPYteuXdHc3BwXX3xxPPbYY3HppZcWeonshzkZWeZk7DIrhdPY2BivvvpqoZfBEJiTkXX99dfHnXfeGZdddlksX748Xn755WhsbIyXXnopfv3rXxd6ecMyeu7VYbjwwgvjE5/4REREfOUrX4lp06bFXXfdFWvWrIkrr7xyv7d56623YtKkSbmvZdy4cXHEEUfkmpl33v+qsbExPvnJT8bjjz8eRUVFERFx1VVXxYwZM+KBBx7wImqUMicjy5yMXWalMLZt2xY333xzXH/99XHTTTcVejkMwpyMnNdffz3uuuuu+PKXvxw//elP+7fPmTMnli1bFo8//nhccsklBVzh8Iypj1G9n8985jMREdHR0RERETU1NVFaWhqvvPJKLFq0KCZPnhxf/OIXIyKit7c3Ghoa4uSTT44jjjgijjnmmFi6dGns2LFjQGaWZXHrrbfG8ccfHyUlJXHuuefGSy+9tM++3+9zg3/84x9j0aJFMXXq1Jg0aVKccsopcffdd/evb9WqVRERAw5N9tnf5wbb2triwgsvjCOPPDJKS0tjwYIF8Yc//GHAdfoOdT733HNx7bXXRnl5eUyaNCk++9nPxvbt2wdct7u7OzZu3Bjd3d2D3r9vvvlmHH300QPW2LeOiRMnDnp7Rgdzspc5YTBmZa9Us9LnhhtuiLlz58aXvvSlId+G0cOc7JViTn7/+99HT09PXHHFFQO29/374YcfPuDtR5sxeWTjvV555ZWIiJg2bVr/tp6enrjgggti/vz5sXLlyv5DfEuXLo3m5uaora2Na665Jjo6OuKee+6Jtra2eO6552LChAkREXHTTTfFrbfeGosWLYpFixbFCy+8EOeff3688847g67n6aefjosvvjimT58ey5cvj2OPPTb++te/xi9/+ctYvnx5LF26NLZu3RpPP/10/OxnPxs076WXXoqzzjorjjzyyLjuuutiwoQJcd9990VVVVW0trbGmWeeOeD6y5Yti6lTp8aKFSuis7MzGhoa4uqrr45HHnmk/zqrV6+O2traaGpqGvBLV/tTVVUVv/jFL6KxsTEuueSS2LVrVzQ2NkZ3d3csX7580PUzOpgTc8LQmJW0sxIR8ac//SkeeOCBWL9+/YAXfIwd5iTdnLz99tsREfu8UdV3fz7//PODrn9UycaQpqamLCKyZ555Jtu+fXv22muvZQ8//HA2bdq0bOLEidmWLVuyLMuyJUuWZBGR3XDDDQNu/+yzz2YRkT344IMDtj/11FMDtm/bti077LDDsosuuijr7e3tv953v/vdLCKyJUuW9G9bu3ZtFhHZ2rVrsyzLsp6enmzWrFnZzJkzsx07dgzYz39nffOb38ze7+6PiGzFihX9/66urs4OO+yw7JVXXunftnXr1mzy5MnZ2Wefvc/9s3DhwgH7+ta3vpWNHz8+6+rq2ue6TU1N+13Df3vjjTeyBQsWZBHR/3XUUUdlv/vd7wa9LSPPnJgThsasFGZWent7szPOOCO78sorsyzLso6Ojiwish/+8IeD3paRZ05Gfk6ef/75LCKyW265ZcD2vvustLT0gLcfbcbkx6gWLlwY5eXlccIJJ8QVV1wRpaWlsXr16pgxY8aA6339618f8O9HH300pkyZEuedd17861//6v86/fTTo7S0NNauXRsREc8880y88847sWzZsgHvuNTV1Q26tra2tujo6Ii6urooKysbcNnBvHuzZ8+e+M1vfhPV1dXx4Q9/uH/79OnT4wtf+EKsX78+3nzzzQG3+epXvzpgX2eddVbs2bMnNm/e3L+tpqYmsiwb0jtQJSUlMXfu3FiyZEk8+uij8ZOf/CSmT58el156afz9738f9vfEyDAn5oShMSsjOyvNzc3xl7/8Je64445hr5/CMScjNyfz5s2LM888M+64445oamqKzs7OePLJJ2Pp0qUxYcKE2Llz57C/p0Iakx+jWrVqVcyZMyeKi4vjmGOOiblz58a4cQN7U3FxcRx//PEDtm3atCm6u7vj6KOP3m/utm3bIiL6HxizZ88ecHl5eXlMnTr1gGvrO6z4sY99bOjf0AFs3749/vOf/8TcuXP3ueykk06K3t7eeO211+Lkk0/u3/5///d/A67Xt+b3fjZyqD73uc9FcXFxPP744/3bFi9eHLNnz47vfe97Aw4RMnqYk73MCYMxK3uNxKy8+eabceONN8Z3vvOdOOGEE4Z9ewrHnOw1Us8pjz32WFx++eVx1VVXRUTE+PHj49prr43W1tb429/+dlCZhTImy8YZZ5zRf0aE93P44YfvMwS9vb1x9NFHx4MPPrjf25SXl+e2xkIaP378frdnWTbsrH/84x/x1FNPxf333z9g+4c+9KGYP39+PPfccwe1RtIzJwdmTuhjVg4sz1lZuXJlvPPOO3H55ZdHZ2dnRERs2bIlIva+KOvs7Izjjjtu2Kc0JT1zcmB5zklExIwZM2L9+vWxadOm+Oc//xmzZ8+OY489No477riYM2fO/7LUETcmy8bBOvHEE+OZZ56JT3/60wc8O8zMmTMjYm8b/+/DZ9u3bx+0oZ544okREfHiiy/GwoUL3/d6Qz2sV15eHiUlJfttsRs3boxx48YlfXeo7w/27NmzZ5/Ldu/eHT09Pcn2TWGYk+EzJx9MZmX4Xn311dixY8eAd4T73HbbbXHbbbdFW1tbVFZWJlsDI8uc/G9mz57df7Tn5Zdfjtdff31IH1ccTcbk72wcrM9//vOxZ8+euOWWW/a5rKenJ7q6uiJi7+cSJ0yYEI2NjQMaaUNDw6D7mDdvXsyaNSsaGhr68/r8d1bfeaffe533Gj9+fJx//vmxZs2a/neBIva+uHnooYdi/vz5ceSRRw66rvca6unXPvKRj8S4cePikUceGbD+LVu2xLPPPhunnXbasPfN6GZO3mVOOBCz8q6hzso111wTq1evHvB13333RcTez7OvXr06Zs2aNez9M3qZk3cdzCmi+/T29sZ1110XJSUl8bWvfW3Yty+kD9SRjXPOOSeWLl0at99+e7S3t8f5558fEyZMiE2bNsWjjz4ad999d1x22WVRXl4e3/72t+P222+Piy++OBYtWhRtbW3x5JNPxlFHHXXAfYwbNy7uvffeuOSSS6KysjJqa2tj+vTpsXHjxgF/9fH000+PiL0/eC+44IIYP378PudT7nPrrbfG008/HfPnz49vfOMbUVxcHPfdd1+8/fbbceeddx7UfTHU06+Vl5fHVVddFT/+8Y9jwYIFcemll8a///3v+NGPfhQ7d+6MG2+88aD2z+hlT
|
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 1000x300 with 4 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))\n",
|
|
|
|
"for ax, image, prediction in zip(axes, X_test, predicted):\n",
|
|
|
|
" ax.set_axis_off()\n",
|
|
|
|
" image = image.reshape(8, 8)\n",
|
|
|
|
" ax.imshow(image, cmap=plt.cm.gray_r, interpolation=\"nearest\")\n",
|
|
|
|
" ax.set_title(f\"Prediction: {prediction}\")"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
":func:`~sklearn.metrics.classification_report` builds a text report showing\n",
|
|
|
|
"the main classification metrics.\n",
|
|
|
|
"\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 6,
|
|
|
|
"metadata": {
|
|
|
|
"collapsed": false
|
|
|
|
},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Classification report for classifier SVC(gamma=0.001):\n",
|
|
|
|
" precision recall f1-score support\n",
|
|
|
|
"\n",
|
|
|
|
" 0 1.00 0.99 0.99 88\n",
|
|
|
|
" 1 0.99 0.97 0.98 91\n",
|
|
|
|
" 2 0.99 0.99 0.99 86\n",
|
|
|
|
" 3 0.98 0.87 0.92 91\n",
|
|
|
|
" 4 0.99 0.96 0.97 92\n",
|
|
|
|
" 5 0.95 0.97 0.96 91\n",
|
|
|
|
" 6 0.99 0.99 0.99 91\n",
|
|
|
|
" 7 0.96 0.99 0.97 89\n",
|
|
|
|
" 8 0.94 1.00 0.97 88\n",
|
|
|
|
" 9 0.93 0.98 0.95 92\n",
|
|
|
|
"\n",
|
|
|
|
" accuracy 0.97 899\n",
|
|
|
|
" macro avg 0.97 0.97 0.97 899\n",
|
|
|
|
"weighted avg 0.97 0.97 0.97 899\n",
|
|
|
|
"\n",
|
|
|
|
"\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"print(\n",
|
|
|
|
" f\"Classification report for classifier {clf}:\\n\"\n",
|
|
|
|
" f\"{metrics.classification_report(y_test, predicted)}\\n\"\n",
|
|
|
|
")"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"We can also plot a `confusion matrix <confusion_matrix>` of the\n",
|
|
|
|
"true digit values and the predicted digit values.\n",
|
|
|
|
"\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 7,
|
|
|
|
"metadata": {
|
|
|
|
"collapsed": false
|
|
|
|
},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Confusion matrix:\n",
|
|
|
|
"[[87 0 0 0 1 0 0 0 0 0]\n",
|
|
|
|
" [ 0 88 1 0 0 0 0 0 1 1]\n",
|
|
|
|
" [ 0 0 85 1 0 0 0 0 0 0]\n",
|
|
|
|
" [ 0 0 0 79 0 3 0 4 5 0]\n",
|
|
|
|
" [ 0 0 0 0 88 0 0 0 0 4]\n",
|
|
|
|
" [ 0 0 0 0 0 88 1 0 0 2]\n",
|
|
|
|
" [ 0 1 0 0 0 0 90 0 0 0]\n",
|
|
|
|
" [ 0 0 0 0 0 1 0 88 0 0]\n",
|
|
|
|
" [ 0 0 0 0 0 0 0 0 88 0]\n",
|
|
|
|
" [ 0 0 0 1 0 1 0 0 0 90]]\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfIAAAHgCAYAAABej+9AAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB1hklEQVR4nO3deVhUZf8G8HtYh3VYZFU2d3EXzXC3SF5fNbfUzAr3t8JyybXczVD75b5lGS5JZqWWLS5ZbomKKGaiiOKCIpsCwzrAzPn9QYyOaDIMzJlh7s91navmzHPOuTmDfOd5ziYRBEEAERERGSUzsQMQERFR1bGQExERGTEWciIiIiPGQk5ERGTEWMiJiIiMGAs5ERGREWMhJyIiMmIs5EREREaMhZyIiMiIsZATiSwxMRG9evWCTCaDRCLB3r17q3X9N2/ehEQiwZYtW6p1vcasR48e6NGjh9gxiKoFCzkRgOvXr+N///sf6tevD6lUCkdHR3Tu3BmrVq1CYWFhjW47LCwMFy9exOLFi7F9+3a0b9++RrenTyNHjoREIoGjo+MT92NiYiIkEgkkEgn+7//+T+v1p6SkYP78+YiLi6uGtETGyULsAERi+/nnnzFkyBBYW1vjzTffRIsWLVBcXIwTJ05g2rRpuHTpEjZt2lQj2y4sLER0dDQ+/PBDTJgwoUa24efnh8LCQlhaWtbI+p/FwsICBQUF2LdvH4YOHarx3o4dOyCVSlFUVFSldaekpGDBggXw9/dHmzZtKr3cwYMHq7Q9IkPEQk4m7caNG3j11Vfh5+eH33//HV5eXur3wsPDce3aNfz88881tv2MjAwAgJOTU41tQyKRQCqV1tj6n8Xa2hqdO3fG119/XaGQR0VFoU+fPvj+++/1kqWgoAC2trawsrLSy/aI9IFD62TSli1bhry8PGzevFmjiJdr2LAhJk6cqH5dWlqKRYsWoUGDBrC2toa/vz8++OADKBQKjeX8/f3Rt29fnDhxAs899xykUinq16+Pbdu2qdvMnz8ffn5+AIBp06ZBIpHA398fQNmQdPn/P2r+/PmQSCQa8w4dOoQuXbrAyckJ9vb2aNKkCT744AP1+087Rv7777+ja9eusLOzg5OTE/r374/Lly8/cXvXrl3DyJEj4eTkBJlMhlGjRqGgoODpO/Yxr732Gn799VdkZ2er58XExCAxMRGvvfZahfYPHjzA1KlT0bJlS9jb28PR0RG9e/fGhQsX1G2OHDmCDh06AABGjRqlHqIv/zl79OiBFi1aIDY2Ft26dYOtra16vzx+jDwsLAxSqbTCzx8aGgpnZ2ekpKRU+mcl0jcWcjJp+/btQ/369dGpU6dKtR87dizmzp2Ldu3aYcWKFejevTsiIiLw6quvVmh77do1vPLKK3jppZfw6aefwtnZGSNHjsSlS5cAAIMGDcKKFSsAAMOHD8f27duxcuVKrfJfunQJffv2hUKhwMKFC/Hpp5/i5Zdfxp9//vmvy/32228IDQ1Feno65s+fjylTpuDkyZPo3Lkzbt68WaH90KFDkZubi4iICAwdOhRbtmzBggULKp1z0KBBkEgk2L17t3peVFQUmjZtinbt2lVon5SUhL1796Jv375Yvnw5pk2bhosXL6J79+7qotqsWTMsXLgQADB+/Hhs374d27dvR7du3dTruX//Pnr37o02bdpg5cqV6Nmz5xPzrVq1Cm5ubggLC4NSqQQAfPbZZzh48CDWrFkDb2/vSv+sRHonEJmonJwcAYDQv3//SrWPi4sTAAhjx47VmD916lQBgPD777+r5/n5+QkAhGPHjqnnpaenC9bW1sL777+vnnfjxg0BgPDJJ59orDMsLEzw8/OrkGHevHnCo/9sV6xYIQAQMjIynpq7fBuRkZHqeW3atBHc3d2F+/fvq+dduHBBMDMzE958880K2xs9erTGOgcOHCi4uro+dZuP/hx2dnaCIAjCK6+8Irz44ouCIAiCUqkUPD09hQULFjxxHxQVFQlKpbLCz2FtbS0sXLhQPS8mJqbCz1aue/fuAgBh48aNT3yve/fuGvMOHDggABA++ugjISkpSbC3txcGDBjwzJ+RSGzskZPJksvlAAAHB4dKtf/ll18AAFOmTNGY//777wNAhWPpgYGB6Nq1q/q1m5sbmjRpgqSkpCpnflz5sfUffvgBKpWqUsvcu3cPcXFxGDlyJFxcXNTzW7VqhZdeekn9cz7qrbfe0njdtWtX3L9/X70PK+O1117DkSNHkJqait9//x2pqalPHFYHyo6rm5mV/XlSKpW4f/+++rDBuXPnKr1Na2trjBo1qlJte/Xqhf/9739YuHAhBg0aBKlUis8++6zS2yISCws5mSxHR0cAQG5ubqXa37p1C2ZmZmjYsKHGfE9PTzg5OeHWrVsa8319fSusw9nZGVlZWVVMXNGwYcPQuXNnjB07Fh4eHnj11Vexa9eufy3q5TmbNGlS4b1mzZohMzMT+fn5GvMf/1mcnZ0BQKuf5b///S8cHBzwzTffYMeOHejQoUOFfVlOpVJhxYoVaNSoEaytrVGnTh24ubnhr7/+Qk5OTqW3WbduXa1ObPu///s/uLi4IC4uDqtXr4a7u3ullyUSCws5mSxHR0d4e3vj77//1mq5x082expzc/MnzhcEocrbKD9+W87GxgbHjh3Db7/9hjfeeAN//fUXhg0bhpdeeqlCW13o8rOUs7a2xqBBg7B161bs2bPnqb1xAPj4448xZcoUdOvWDV999RUOHDiAQ4cOoXnz5pUeeQDK9o82zp8/j/T0dADAxYsXtVqWSCws5GTS+vbti+vXryM6OvqZbf38/KBSqZCYmKgxPy0tDdnZ2eoz0KuDs7Ozxhne5R7v9QOAmZkZXnzxRSxfvhzx8fFYvHgxfv/9d/zxxx9PXHd5zoSEhArvXblyBXXq1IGdnZ1uP8BTvPbaazh//jxyc3OfeIJgue+++w49e/bE5s2b8eqrr6JXr14ICQmpsE8q+6WqMvLz8zFq1CgEBgZi/PjxWLZsGWJiYqpt/UQ1hYWcTNr06dNhZ2eHsWPHIi0trcL7169fx6pVqwCUDQ0DqHBm+fLlywEAffr0qbZcDRo0QE5ODv766y/1vHv37mHPnj0a7R48eFBh2fIbozx+SVw5Ly8vtGnTBlu3btUojH///TcOHjyo/jlrQs+ePbFo0SKsXbsWnp6eT21nbm5eobf/7bff4u7duxrzyr9wPOlLj7ZmzJiB27dvY+vWrVi+fDn8/f0RFhb21P1IZCh4QxgyaQ0aNEBUVBSGDRuGZs2aadzZ7eTJk/j2228xcuRIAEDr1q0RFhaGTZs2ITs7G927d8eZM2ewdetWDBgw4KmXNlXFq6++ihkzZmDgwIF47733UFBQgA0bNqBx48YaJ3stXLgQx44dQ58+feDn54f09HSsX78e9erVQ5cuXZ66/k8++QS9e/dGcHAwxowZg8LCQqxZswYymQzz58+vtp/jcWZmZpg9e/Yz2/Xt2xcLFy7EqFGj0KlTJ1y8eBE7duxA/fr1Ndo1aNAATk5O2LhxIxwcHGBnZ4eOHTsiICBAq1y///471q9fj3nz5qkvh4uMjESPHj0wZ84cLFu2TKv1EemVyGfNExmEq1evCuPGjRP8/f0FKysrwcHBQejcubOwZs0aoaioSN2upKREWLBggRAQECBYWloKPj4+wqxZszTaCELZ5Wd9+vSpsJ3HL3t62uVngiAIBw8eFFq0aCFYWVkJTZo0Eb766qsKl58dPnxY6N+/v+Dt7S1YWVkJ3t7ewvDhw4WrV69W2Mbjl2j99ttvQufOnQUbGxvB0dFR6NevnxAfH6/Rpnx7j1/eFhkZKQAQbty48dR9Kgial589zdMuP3v//fcFLy8vwcbGRujcubMQHR39xMvGfvjhByEwMFCwsLDQ+Dm7d+8uNG/e/InbfHQ9crlc8PPzE9q1ayeUlJRotJs8ebJgZmYmREdH/+vPQCQmiSBocbYKERERGRQeIyciIjJiLORERERGjIWciIjIiLGQExERGTEWciIiIiPGQk5ERGTEWMiJiIiMG
|
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 640x480 with 2 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)\n",
|
|
|
|
"disp.figure_.suptitle(\"Confusion Matrix\")\n",
|
|
|
|
"print(f\"Confusion matrix:\\n{disp.confusion_matrix}\")\n",
|
|
|
|
"\n",
|
|
|
|
"plt.show()"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"If the results from evaluating a classifier are stored in the form of a\n",
|
|
|
|
"`confusion matrix <confusion_matrix>` and not in terms of `y_true` and\n",
|
|
|
|
"`y_pred`, one can still build a :func:`~sklearn.metrics.classification_report`\n",
|
|
|
|
"as follows:\n",
|
|
|
|
"\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 8,
|
|
|
|
"metadata": {
|
|
|
|
"collapsed": false
|
|
|
|
},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Classification report rebuilt from confusion matrix:\n",
|
|
|
|
" precision recall f1-score support\n",
|
|
|
|
"\n",
|
|
|
|
" 0 1.00 0.99 0.99 88\n",
|
|
|
|
" 1 0.99 0.97 0.98 91\n",
|
|
|
|
" 2 0.99 0.99 0.99 86\n",
|
|
|
|
" 3 0.98 0.87 0.92 91\n",
|
|
|
|
" 4 0.99 0.96 0.97 92\n",
|
|
|
|
" 5 0.95 0.97 0.96 91\n",
|
|
|
|
" 6 0.99 0.99 0.99 91\n",
|
|
|
|
" 7 0.96 0.99 0.97 89\n",
|
|
|
|
" 8 0.94 1.00 0.97 88\n",
|
|
|
|
" 9 0.93 0.98 0.95 92\n",
|
|
|
|
"\n",
|
|
|
|
" accuracy 0.97 899\n",
|
|
|
|
" macro avg 0.97 0.97 0.97 899\n",
|
|
|
|
"weighted avg 0.97 0.97 0.97 899\n",
|
|
|
|
"\n",
|
|
|
|
"\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"# The ground truth and predicted lists\n",
|
|
|
|
"y_true = []\n",
|
|
|
|
"y_pred = []\n",
|
|
|
|
"cm = disp.confusion_matrix\n",
|
|
|
|
"\n",
|
|
|
|
"# For each cell in the confusion matrix, add the corresponding ground truths\n",
|
|
|
|
"# and predictions to the lists\n",
|
|
|
|
"for gt in range(len(cm)):\n",
|
|
|
|
" for pred in range(len(cm)):\n",
|
|
|
|
" y_true += [gt] * cm[gt][pred]\n",
|
|
|
|
" y_pred += [pred] * cm[gt][pred]\n",
|
|
|
|
"\n",
|
|
|
|
"print(\n",
|
|
|
|
" \"Classification report rebuilt from confusion matrix:\\n\"\n",
|
|
|
|
" f\"{metrics.classification_report(y_true, y_pred)}\\n\"\n",
|
|
|
|
")"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
2023-01-30 11:26:08 +01:00
|
|
|
"version": "3.8.10 (default, Nov 14 2022, 12:59:47) \n[GCC 9.4.0]"
|
2022-12-29 17:27:20 +01:00
|
|
|
},
|
|
|
|
"vscode": {
|
|
|
|
"interpreter": {
|
|
|
|
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
|
|
|
}
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 0
|
|
|
|
}
|