M2_SETI/D3/TP/TP_SETI_SVM/plot_digits_classification.ipynb

364 lines
64 KiB
Text
Raw Normal View History

2022-12-29 17:27:20 +01:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"# Recognizing hand-written digits\n",
"\n",
"This example shows how scikit-learn can be used to recognize images of\n",
"hand-written digits, from 0-9.\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>\n",
"# License: BSD 3 clause\n",
"\n",
"# Standard scientific Python imports\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Import datasets, classifiers and performance metrics\n",
"from sklearn import datasets, svm, metrics\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Digits dataset\n",
"\n",
"The digits dataset consists of 8x8\n",
"pixel images of digits. The ``images`` attribute of the dataset stores\n",
"8x8 arrays of grayscale values for each image. We will use these arrays to\n",
"visualize the first 4 images. The ``target`` attribute of the dataset stores\n",
"the digit each image represents and this is included in the title of the 4\n",
"plots below.\n",
"\n",
"Note: if we were working from image files (e.g., 'png' files), we would load\n",
"them using `matplotlib.pyplot.imread`.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxsAAADSCAYAAAAi0d0oAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAA9hAAAPYQGoP6dpAAASFklEQVR4nO3db5CVZd0H8N8KsRsBsiLkkiUsOmPJIAHNJCbgsBCkBkmgLxhZxgYqGaM/M8sU5oJlkjZjhRnxBgNzlDLIJlMY2JymN7GyloYzSyw6GU6Kyx9F/no/L57HfaIld8Hr8rC7n88MM+x1zv29rwP82POd++w5ZUVRFAEAAJDYOaXeAAAA0D0pGwAAQBbKBgAAkIWyAQAAZKFsAAAAWSgbAABAFsoGAACQhbIBAABkoWwAAABZKBtnoLa2NoYNG3ZGx9bX10dZWVnaDcFZyJxAx8wJdMycdG3dqmyUlZV16ldDQ0Opt3rW+dOf/hSf+tSnom/fvnHBBRfErbfeGq+//nqpt0UG5uTMPPnkk3HzzTfHyJEjo1evXmf8jY+uwZycvkOHDsV9990XU6dOjaqqqujfv398/OMfj/vvvz9OnDhR6u2RgTk5M3feeWd88pOfjMGDB0dFRUVccsklsXjx4njllVdKvbUsyoqiKEq9iVTWrVt30tc///nPY9OmTbF27dqT1qdMmRIf/OAHz/g8x44di7feeivKy8tP+9jjx4/H8ePHo6Ki4ozPn1pTU1NcccUV8dGPfjQWLFgQ//jHP+Kee+6Jq6++Oh5//PFSb4/EzMmZqa2tjYcffjjGjBkTL774YvTq1St2795d6m2RiTk5fc8++2yMGjUqJk+eHFOnTo0BAwbEE088Eb/+9a/jpptuigceeKDUWyQxc3JmZs2aFYMHD45LL700+vfvHzt27IjVq1fHkCFDoqmpKT7wgQ+UeotpFd3YLbfcUnTmIb7xxhvvwW7OXtOnTy+qqqqK/fv3t62tXr26iIjiiSeeKOHOeC+Yk8556aWXiqNHjxZFURTXXHNNcdFFF5V2Q7ynzEnHXnnlleLZZ59ttz5//vwiIorm5uYS7Ir3kjk5c7/85S+LiCgeeuihUm8luW71MqrOmDRpUowcOTIaGxtjwoQJ0bdv3/jmN78ZEREbN26Ma665JoYOHRrl5eUxYsSIuOOOO9pd/v3P1w7u3r07ysrK4p577omf/exnMWLEiCgvL49PfOIT8ec///mkY0/12sGysrJYtGhRbNiwIUaOHBnl5eVx2WWXxe9///t2+29oaIhx48ZFRUVFjBgxIlatWnXKzFdffTWef/75OHTo0Dv+eRw4cCA2bdoUc+fOjQEDBrSt33TTTdGvX7945JFH3vF4uidz0t7QoUPjfe97X4f3o+cwJyc7//zz47LLLmu3/rnPfS4iInbs2PGOx9M9mZPOefvx7du374yOP5v1LvUGSmHv3r0xffr0uPHGG2Pu3Lltl/bWrFkT/fr1i6997WvRr1+/2LJlS3z729+OAwcOxN13391h7i9+8Ys4ePBgLFy4MMrKyuL73/9+XH/99bFr164On6T88Y9/jEcffTS+/OUvR//+/eNHP/pRzJo1K1588cUYNGhQRERs3749pk2bFlVVVbFs2bI4ceJELF++PAYPHtwub+XKlbFs2bLYunVrTJo06b+e969//WscP348xo0bd9J6nz59YvTo0bF9+/YOHzfdkzmBjpmTjr388ssR8b9lhJ7JnLRXFEXs3bs3jh8/Hs3NzbFkyZLo1atX9/xeVOpLKzmd6nLexIkTi4gofvrTn7a7/6FDh9qtLVy4sOjbt29x+PDhtrV58+ad9BKKlpaWIiKKQYMGFa+99lrb+saNG4uIKB577LG2tdtvv73dniKi6NOnT7Fz5862tWeeeaaIiOLHP/5x29p1111X9O3bt3jppZfa1pqbm4vevXu3y3z7PFu3bm33mP7d+vXri4gonnrqqX
a3zZ49u7jgggve8Xi6PnPS8Zz8Jy+j6nnMyenPSVEUxZEjR4qPfexjxfDhw4tjx46d9vF0Leak83OyZ8+eIiLafl144YXFww8/3Klju5oe9zKqiIjy8vKYP39+u/X3v//9bb8/ePBgvPrqq3HVVVfFoUOH4vnnn+8w94YbbojKysq2r6+66qqIiNi1a1eHx9bU1MSIESPavh41alQMGDCg7dgTJ07E5s2bY+bMmTF06NC2+1188cUxffr0dnn19fVRFEWHDfnNN9+MiDjlD11VVFS03U7PY06gY+bknS1atCj+9re/xcqVK6N37x75YgrCnJzKeeedF5s2bYrHHnssli9fHueff363fRfQHjn5H/rQh6JPnz7t1p977rlYunRpbNmyJQ4cOHDSbfv37+8w9yMf+chJX789AK2trad97NvHv33sv/71r3jzzTfj4osvbne/U6111tuDfuTIkXa3HT58+KT/COhZzAl0zJz8d3fffXesXr067rjjjvjMZz6TLJeux5y016dPn6ipqYmIiGuvvTYmT54cV155ZQwZMiSuvfbad51/NumRZeNUT6D37dsXEydOjAEDBsTy5ctjxIgRUVFREU8//XTU1dXFW2+91WFur169TrledOLdhd/Nse9GVVVVRETs2bOn3W179uw5qc3Ts5gT6Jg5ObU1a9ZEXV1dfPGLX4ylS5e+Z+fl7GROOjZ+/PioqqqKBx98UNnorhoaGmLv3r3x6KOPxoQJE9rWW1paSrir/zdkyJCoqKiInTt3trvtVGudNXLkyOjdu3ds27Yt5syZ07Z+9OjRaGpqOmkNeuqcwOno6XOycePG+MIXvhDXX3993Hfffe86j+6pp8/JqRw+fLhTV3S6mh75Mxun8nbD/fdGe/To0fjJT35Sqi2dpFevXlFTUxMbNmyIf/7zn23rO3fuPOUH73X2LdjOPffcqKmpiXXr1sXBgwfb1teuXRuvv/56zJ49O92DoMvrqXMCp6Mnz8lTTz0VN954Y0yYMCEefPDBOOccTzM4tZ46J2+88cYp7/OrX/0qWltb2707aHfgysb/GT9+fFRWVsa8efPi1ltvjbKysli7du1Z9fKM+vr6ePLJJ+PKK6+ML33pS3HixIlYuXJljBw5Mpqamk667+m8Bdt3v/vdGD9+fEycOLHtE8R/8IMfxNSpU2PatGn5HhBdTk+ek7/85S/xm9/8JiL+95vN/v374zvf+U5ERFx++eVx3XXX5Xg4dEE9dU5eeOGF+OxnPxtlZWXx+c9/PtavX3/S7aNGjYpRo0ZleDR0RT11Tpqbm6OmpiZuuOGGuPTSS+Occ86Jbdu2xbp162LYsGHxla98Je+DKgFl4/8MGjQofvvb38bXv/71WLp0aVRWVsbcuXNj8uTJ8elPf7rU24uIiLFjx8bjjz8e3/jGN+K2226LD3/4w7F8+fLYsWNHp9614b8ZM2ZMbN68Oerq6uKrX/1q9O/fP26++eb43ve+l3D3dAc9eU6efvrpuO22205ae/vrefPmKRu06alz0tLS0vYSkFtuuaXd7bfffruyQZueOicXXnhhzJo1K7Zs2RIPPPBAHDt2LC666KJYtGhRfOtb32r7jI/upKw4myokZ2TmzJnx3HPPRXNzc6m3AmctcwIdMyfQMXNyeryYsov5z8+9aG5ujt/97nc+JwD+jTmBjpkT6Jg5efdc2ehiqqqqora2Nqqrq+OFF16I+++/P44cORLbt2+PSy65pNTbg7OCOYGOmRPomDl59/zMRhczbdq0eOihh+Lll1+O8vLyuOKKK+LOO+/0Dx7+jTmBjpkT6Jg5efdc2QAAALLwMxsAAEAWygYAAJCFsgEAAGTR7X5A/D8/sTSFurq65JlTpkxJnhkRcddddyXPrKysTJ5J95PjbQD37duXPDMiYtmyZckzZ8yYkTyT7qehoSF55syZM5NnRkSMHj06eWaOx0/prVixInnmkiVLkmcOHz48eWZERGNjY/
LM7vTcy5UNAAAgC2UDAADIQtkAAACyUDYAAIAslA0AACALZQMAAMhC2QAAALJQNgAAgCyUDQAAIAtlAwAAy
"text/plain": [
"<Figure size 1000x300 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Load the bundled digits dataset: `images` holds the 8x8 grayscale arrays\n",
"# and `target` the digit each one represents.\n",
"digits = datasets.load_digits()\n",
"\n",
"# One row of four panels; zip() stops once the four axes are used up, so only\n",
"# the first four images and labels are drawn.\n",
"_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))\n",
"for ax, image, label in zip(axes, digits.images, digits.target):\n",
"    ax.set_axis_off()\n",
"    ax.imshow(image, cmap=plt.cm.gray_r, interpolation=\"nearest\")\n",
"    # f-string keeps the title formatting consistent with the prediction plot\n",
"    ax.set_title(f\"Training: {label}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification\n",
"\n",
"To apply a classifier on this data, we need to flatten the images, turning\n",
"each 2-D array of grayscale values from shape ``(8, 8)`` into shape\n",
"``(64,)``. Subsequently, the entire dataset will be of shape\n",
"``(n_samples, n_features)``, where ``n_samples`` is the number of images and\n",
"``n_features`` is the total number of pixels in each image.\n",
"\n",
"We can then split the data into train and test subsets and fit a support\n",
"vector classifier on the train samples. The fitted classifier can\n",
"subsequently be used to predict the value of the digit for the samples\n",
"in the test subset.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Flatten every 8x8 image into a single row of 64 pixel features\n",
"n_samples = digits.images.shape[0]\n",
"data = digits.images.reshape(n_samples, -1)\n",
"\n",
"# Keep half of the samples for testing; shuffle=False preserves the dataset\n",
"# order, so the split is identical on every run\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    data, digits.target, test_size=0.5, shuffle=False\n",
")\n",
"\n",
"# Support vector classifier, trained on the train subset only\n",
"clf = svm.SVC(gamma=0.001)\n",
"clf.fit(X_train, y_train)\n",
"\n",
"# Predicted digit for every sample of the test subset\n",
"predicted = clf.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below we visualize the first 4 test samples and show their predicted\n",
"digit value in the title.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxsAAADSCAYAAAAi0d0oAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUXUlEQVR4nO3dfWxV9f0H8E+hqJQiZVgV0R9lDohz04qbLhlqHagRNXTGTd1DaM0y9iDSmc2HLZFOjU5HTE1lRrOsdYtG5wwlbuqmGTTiHqMtmToW5lqU4IQltC4OlNLz+4O0roK0Zefb2+LrlfQPzr33fb693E/vfd9ze1qUZVkWAAAAORtX6AUAAACHJmUDAABIQtkAAACSUDYAAIAklA0AACAJZQMAAEhC2QAAAJJQNgAAgCSUDQAAIAllY5gqKiqipqam/9/r1q2LoqKiWLduXW77KCoqivr6+tzyYKSZExgaswKDMydj25gqG83NzVFUVNT/dcQRR8ScOXPi6quvjjfeeKPQyxuWJ554Ykw9qH/+85/Hpz71qSgrK4tp06bFOeecE7/61a8KvSz2w5wUjjkZW8xK4e3evTs++tGPRlFRUaxcubLQy2E/zEnh3HPPPXHSSSfF4YcfHjNmzIhrr7023nrrrUIva9iKC72Ag3HzzTfHrFmzYteuXbF+/fq4995744knnogXX3wxSkpKRnQtZ599duzcuTMOO+ywYd3uiSeeiFWrVu33Qb9z584oLh49/zWNjY1xzTXXxEUXXRQ/+MEPYteuXdHc3BwXX3xxPPbYY3HppZcWeonshzkZWeZk7DIrhdPY2BivvvpqoZfBEJiTkXX99dfHnXfeGZdddlksX748Xn755WhsbIyXXnopfv3rXxd6ecMyeu7VYbjwwgvjE5/4REREfOUrX4lp06bFXXfdFWvWrIkrr7xyv7d56623YtKkSbmvZdy4cXHEEUfkmpl33v+qsbExPvnJT8bjjz8eRUVFERFx1VVXxYwZM+KBBx7wImqUMicjy5yMXWalMLZt2xY333xzXH/99XHTTTcVejkMwpyMnNdffz3uuuuu+PKXvxw//elP+7fPmTMnli1bFo8//nhccsklBVzh8Iypj1G9n8985jMREdHR0RERETU1NVFaWhqvvPJKLFq0KCZPnhxf/OIXIyKit7c3Ghoa4uSTT44jjjgijjnmmFi6dGns2LFjQGaWZXHrrbfG8ccfHyUlJXHuuefGSy+9tM++3+9zg3/84x9j0aJFMXXq1Jg0aVKccsopcffdd/evb9WqVRERAw5N9tnf5wbb2triwgsvjCOPPDJKS0tjwYIF8Yc//GHAdfoOdT733HNx7bXXRnl5eUyaNCk++9nPxvbt2wdct7u7OzZu3Bjd3d2D3r9vvvlmHH300QPW2LeOiRMnDnp7Rgdzspc5YTBmZa9Us9LnhhtuiLlz58aXvvSlId+G0cOc7JViTn7/+99HT09PXHHFFQO29/374YcfPuDtR5sxeWTjvV555ZWIiJg2bVr/tp6enrjgggti/vz5sXLlyv5DfEuXLo3m5uaora2Na665Jjo6OuKee+6Jtra2eO6552LChAkREXHTTTfFrbfeGosWLYpFixbFCy+8EOeff3688847g67n6aefjosvvjimT58ey5cvj2OPPTb++te/xi9/+ctYvnx5LF26NLZu3RpPP/10/OxnPxs076WXXoqzzjorjjzyyLjuuutiwoQJcd9990VVVVW0trbGmWeeOeD6y5Yti6lTp8aKFSuis7MzGhoa4uqrr45HHnmk/zqrV6+O2traaGpqGvBLV/tTVVUVv/jFL6KxsTEuueSS2LVrVzQ2NkZ3d3csX7580PUzOpgTc8LQmJW0sxIR8ac//SkeeOCBWL9+/YAXfIwd5iTdnLz99tsREfu8UdV3fz7//PODrn9UycaQpqamLCKyZ555Jtu+fXv22muvZQ8//HA2bdq0bOLEid
mWLVuyLMuyJUuWZBGR3XDDDQNu/+yzz2YRkT344IMDtj/11FMDtm/bti077LDDsosuuijr7e3tv953v/vdLCKyJUuW9G9bu3ZtFhHZ2rVrsyzLsp6enmzWrFnZzJkzsx07dgzYz39nffOb38ze7+6PiGzFihX9/66urs4OO+yw7JVXXunftnXr1mzy5MnZ2Wefvc/9s3DhwgH7+ta3vpWNHz8+6+rq2ue6TU1N+13Df3vjjTeyBQsWZBHR/3XUUUdlv/vd7wa9LSPPnJgThsasFGZWent7szPOOCO78sorsyzLso6Ojiwish/+8IeD3paRZ05Gfk6ef/75LCKyW265ZcD2vvustLT0gLcfbcbkx6gWLlwY5eXlccIJJ8QVV1wRpaWlsXr16pgxY8aA6339618f8O9HH300pkyZEuedd17861//6v86/fTTo7S0NNauXRsREc8880y88847sWzZsgHvuNTV1Q26tra2tujo6Ii6urooKysbcNnBvHuzZ8+e+M1vfhPV1dXx4Q9/uH/79OnT4wtf+EKsX78+3nzzzQG3+epXvzpgX2eddVbs2bMnNm/e3L+tpqYmsiwb0jtQJSUlMXfu3FiyZEk8+uij8ZOf/CSmT58el156afz9738f9vfEyDAn5oShMSsjOyvNzc3xl7/8Je64445hr5/CMScjNyfz5s2LM888M+64445oamqKzs7OePLJJ2Pp0qUxYcKE2Llz57C/p0Iakx+jWrVqVcyZMyeKi4vjmGOOiblz58a4cQN7U3FxcRx//PEDtm3atCm6u7vj6KOP3m/utm3bIiL6HxizZ88ecHl5eXlMnTr1gGvrO6z4sY99bOjf0AFs3749/vOf/8TcuXP3ueykk06K3t7eeO211+Lkk0/u3/5///d/A67Xt+b3fjZyqD73uc9FcXFxPP744/3bFi9eHLNnz47vfe97Aw4RMnqYk73MCYMxK3uNxKy8+eabceONN8Z3vvOdOOGEE4Z9ewrHnOw1Us8pjz32WFx++eVx1VVXRUTE+PHj49prr43W1tb429/+dlCZhTImy8YZZ5zRf0aE93P44YfvMwS9vb1x9NFHx4MPPrjf25SXl+e2xkIaP378frdnWTbsrH/84x/x1FNPxf333z9g+4c+9KGYP39+PPfccwe1RtIzJwdmTuhjVg4sz1lZuXJlvPPOO3H55ZdHZ2dnRERs2bIlIva+KOvs7Izjjjtu2Kc0JT1zcmB5zklExIwZM2L9+vWxadOm+Oc//xmzZ8+OY489No477riYM2fO/7LUETcmy8bBOvHEE+OZZ56JT3/60wc8O8zMmTMjYm8b/+/DZ9u3bx+0oZ544okREfHiiy/GwoUL3/d6Qz2sV15eHiUlJfttsRs3boxx48YlfXeo7w/27NmzZ5/Ldu/eHT09Pcn2TWGYk+EzJx9MZmX4Xn311dixY8eAd4T73HbbbXHbbbdFW1tbVFZWJlsDI8uc/G9mz57df7Tn5Zdfjtdff31IH1ccTcbk72wcrM9//vOxZ8+euOWWW/a5rKenJ7q6uiJi7+cSJ0yYEI2NjQMaaUNDw6D7mDdvXsyaNSsaGhr68/r8d1bfeaffe533Gj9+fJx//vmxZs2a/neBIva+uHnooYdi/vz5ceSRRw66rvca6unXPvKRj8S4cePikUceGbD+LVu2xLPPPhunnXbasPfN6GZO3mVOOBCz8q6hzso111wTq1evHvB13333RcTez7OvXr06Zs2aNez9M3qZk3cdzCmi+/T29sZ1110XJSUl8bWvfW3Yty+kD9SRjXPOOSeWLl0at99+e7S3t8f5558fEyZMiE2bNsWjjz4ad999d1x22WVRXl4e3/72t+P222+Piy++OBYtWhRtbW3x5JNPxlFHHXXAfYwbNy7uvffeuOSSS6KysjJqa2tj+vTpsXHjxgF/9fH000+PiL0/eC+44IIYP378PudT7nPrrbfG008/HfPnz49vfOMbUVxcHP
fdd1+8/fbbceeddx7UfTHU06+Vl5fHVVddFT/+8Y9jwYIFcemll8a///3v+NGPfhQ7d+6MG2+88aD2z+hlT
"text/plain": [
"<Figure size 1000x300 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Rows of X_test are flat, so reshape them back to 8x8 images for display;\n",
"# zip() stops after the four axes, so only the first four samples are shown\n",
"_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))\n",
"for ax, image, prediction in zip(axes, X_test.reshape(-1, 8, 8), predicted):\n",
"    ax.set_axis_off()\n",
"    ax.imshow(image, cmap=plt.cm.gray_r, interpolation=\"nearest\")\n",
"    ax.set_title(f\"Prediction: {prediction}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`sklearn.metrics.classification_report` builds a text report showing\n",
"the main classification metrics.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification report for classifier SVC(gamma=0.001):\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.99 0.99 88\n",
" 1 0.99 0.97 0.98 91\n",
" 2 0.99 0.99 0.99 86\n",
" 3 0.98 0.87 0.92 91\n",
" 4 0.99 0.96 0.97 92\n",
" 5 0.95 0.97 0.96 91\n",
" 6 0.99 0.99 0.99 91\n",
" 7 0.96 0.99 0.97 89\n",
" 8 0.94 1.00 0.97 88\n",
" 9 0.93 0.98 0.95 92\n",
"\n",
" accuracy 0.97 899\n",
" macro avg 0.97 0.97 0.97 899\n",
"weighted avg 0.97 0.97 0.97 899\n",
"\n",
"\n"
]
}
],
"source": [
"print(\n",
" f\"Classification report for classifier {clf}:\\n\"\n",
" f\"{metrics.classification_report(y_test, predicted)}\\n\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also plot a [confusion matrix](https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix) of the\n",
"true digit values and the predicted digit values.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Confusion matrix:\n",
"[[87 0 0 0 1 0 0 0 0 0]\n",
" [ 0 88 1 0 0 0 0 0 1 1]\n",
" [ 0 0 85 1 0 0 0 0 0 0]\n",
" [ 0 0 0 79 0 3 0 4 5 0]\n",
" [ 0 0 0 0 88 0 0 0 0 4]\n",
" [ 0 0 0 0 0 88 1 0 0 2]\n",
" [ 0 1 0 0 0 0 90 0 0 0]\n",
" [ 0 0 0 0 0 1 0 88 0 0]\n",
" [ 0 0 0 0 0 0 0 0 88 0]\n",
" [ 0 0 0 1 0 1 0 0 0 90]]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfIAAAHgCAYAAABej+9AAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB1hklEQVR4nO3deVhUZf8G8HtYh3VYZFU2d3EXzXC3SF5fNbfUzAr3t8JyybXczVD75b5lGS5JZqWWLS5ZbomKKGaiiOKCIpsCwzrAzPn9QYyOaDIMzJlh7s91navmzHPOuTmDfOd5ziYRBEEAERERGSUzsQMQERFR1bGQExERGTEWciIiIiPGQk5ERGTEWMiJiIiMGAs5ERGREWMhJyIiMmIs5EREREaMhZyIiMiIsZATiSwxMRG9evWCTCaDRCLB3r17q3X9N2/ehEQiwZYtW6p1vcasR48e6NGjh9gxiKoFCzkRgOvXr+N///sf6tevD6lUCkdHR3Tu3BmrVq1CYWFhjW47LCwMFy9exOLFi7F9+3a0b9++RrenTyNHjoREIoGjo+MT92NiYiIkEgkkEgn+7//+T+v1p6SkYP78+YiLi6uGtETGyULsAERi+/nnnzFkyBBYW1vjzTffRIsWLVBcXIwTJ05g2rRpuHTpEjZt2lQj2y4sLER0dDQ+/PBDTJgwoUa24efnh8LCQlhaWtbI+p/FwsICBQUF2LdvH4YOHarx3o4dOyCVSlFUVFSldaekpGDBggXw9/dHmzZtKr3cwYMHq7Q9IkPEQk4m7caNG3j11Vfh5+eH33//HV5eXur3wsPDce3aNfz88881tv2MjAwAgJOTU41tQyKRQCqV1tj6n8Xa2hqdO3fG119/XaGQR0VFoU+fPvj+++/1kqWgoAC2trawsrLSy/aI9IFD62TSli1bhry8PGzevFmjiJdr2LAhJk6cqH5dWlqKRYsWoUGDBrC2toa/vz8++OADKBQKjeX8/f3Rt29fnDhxAs899xykUinq16+Pbdu2qdvMnz8ffn5+AIBp06ZBIpHA398fQNmQdPn/P2r+/PmQSCQa8w4dOoQuXbrAyckJ9vb2aNKkCT744AP1+087Rv7777+ja9eusLOzg5OTE/r374/Lly8/cXvXrl3DyJEj4eTkBJlMhlGjRqGgoODpO/Yxr732Gn799VdkZ2er58XExCAxMRGvvfZahfYPHjzA1KlT0bJlS9jb28PR0RG9e/fGhQsX1G2OHDmCDh06AABGjRqlHqIv/zl79OiBFi1aIDY2Ft26dYOtra16vzx+jDwsLAxSqbTCzx8aGgpnZ2ekpKRU+mcl0jcWcjJp+/btQ/369dGpU6dKtR87dizmzp2Ldu3aYcWKFejevTsiIiLw6quvVmh77do1vPLKK3jppZfw6aefwtnZGSNHjsSlS5cAAIMGDcKKFSsAAMOHD8f27duxcuVKrfJfunQJffv2hUKhwMKFC/Hpp5/i5Zdfxp9//vmvy/32228IDQ1Feno65s+fjylTpuDkyZPo3Lkzbt68WaH90KFDkZubi4iICAwdOhRbtmzBggULKp1z0KBBkEgk2L17t3peVFQUmjZtinbt2lVon5SUhL1796Jv375Yvnw5pk2bhosXL6J79+7qotqsWTMsXLgQADB+/Hhs374d27dvR7du3dTruX//Pnr37o02bdpg5cqV6Nmz5xPzrVq1Cm5ubggLC4NSqQQAfPbZZzh48CDWrFkDb2/vSv+sRHonEJmonJwcAYDQv3//SrWPi4sTAAhjx47VmD916lQBgPD777+r5/n5+QkAhGPHjqnnpaenC9bW1sL777+vnnfjxg0BgPDJJ59orDMsLEzw8/OrkGHevHnCo/9sV6xYIQAQMjIynpq7fBuRkZHqeW3atBHc3d2F+/fvq+dduHBBMDMzE958880K2xs9erTGOgcOHCi4uro+dZuP/hx2dnaCIAjCK6+8Irz44ouCIAiCUq
kUPD09hQULFjxxHxQVFQlKpbLCz2FtbS0sXLhQPS8mJqbCz1aue/fuAgBh48aNT3yve/fuGvMOHDggABA++ugjISkpSbC3txcGDBjwzJ+RSGzskZPJksvlAAAHB4dKtf/ll18AAFOmTNGY//777wNAhWPpgYGB6Nq1q/q1m5sbmjRpgqSkpCpnflz5sfUffvgBKpWqUsvcu3cPcXFxGDlyJFxcXNTzW7VqhZdeekn9cz7qrbfe0njdtWtX3L9/X70PK+O1117DkSNHkJqait9//x2pqalPHFYHyo6rm5mV/XlSKpW4f/+++rDBuXPnKr1Na2trjBo1qlJte/Xqhf/9739YuHAhBg0aBKlUis8++6zS2yISCws5mSxHR0cAQG5ubqXa37p1C2ZmZmjYsKHGfE9PTzg5OeHWrVsa8319fSusw9nZGVlZWVVMXNGwYcPQuXNnjB07Fh4eHnj11Vexa9eufy3q5TmbNGlS4b1mzZohMzMT+fn5GvMf/1mcnZ0BQKuf5b///S8cHBzwzTffYMeOHejQoUOFfVlOpVJhxYoVaNSoEaytrVGnTh24ubnhr7/+Qk5OTqW3WbduXa1ObPu///s/uLi4IC4uDqtXr4a7u3ullyUSCws5mSxHR0d4e3vj77//1mq5x082expzc/MnzhcEocrbKD9+W87GxgbHjh3Db7/9hjfeeAN//fUXhg0bhpdeeqlCW13o8rOUs7a2xqBBg7B161bs2bPnqb1xAPj4448xZcoUdOvWDV999RUOHDiAQ4cOoXnz5pUeeQDK9o82zp8/j/T0dADAxYsXtVqWSCws5GTS+vbti+vXryM6OvqZbf38/KBSqZCYmKgxPy0tDdnZ2eoz0KuDs7Ozxhne5R7v9QOAmZkZXnzxRSxfvhzx8fFYvHgxfv/9d/zxxx9PXHd5zoSEhArvXblyBXXq1IGdnZ1uP8BTvPbaazh//jxyc3OfeIJgue+++w49e/bE5s2b8eqrr6JXr14ICQmpsE8q+6WqMvLz8zFq1CgEBgZi/PjxWLZsGWJiYqpt/UQ1hYWcTNr06dNhZ2eHsWPHIi0trcL7169fx6pVqwCUDQ0DqHBm+fLlywEAffr0qbZcDRo0QE5ODv766y/1vHv37mHPnj0a7R48eFBh2fIbozx+SVw5Ly8vtGnTBlu3btUojH///TcOHjyo/jlrQs+ePbFo0SKsXbsWnp6eT21nbm5eobf/7bff4u7duxrzyr9wPOlLj7ZmzJiB27dvY+vWrVi+fDn8/f0RFhb21P1IZCh4QxgyaQ0aNEBUVBSGDRuGZs2aadzZ7eTJk/j2228xcuRIAEDr1q0RFhaGTZs2ITs7G927d8eZM2ewdetWDBgw4KmXNlXFq6++ihkzZmDgwIF47733UFBQgA0bNqBx48YaJ3stXLgQx44dQ58+feDn54f09HSsX78e9erVQ5cuXZ66/k8++QS9e/dGcHAwxowZg8LCQqxZswYymQzz58+vtp/jcWZmZpg9e/Yz2/Xt2xcLFy7EqFGj0KlTJ1y8eBE7duxA/fr1Ndo1aNAATk5O2LhxIxwcHGBnZ4eOHTsiICBAq1y///471q9fj3nz5qkvh4uMjESPHj0wZ84cLFu2TKv1EemVyGfNExmEq1evCuPGjRP8/f0FKysrwcHBQejcubOwZs0aoaioSN2upKREWLBggRAQECBYWloKPj4+wqxZszTaCELZ5Wd9+vSpsJ3HL3t62uVngiAIBw8eFFq0aCFYWVkJTZo0Eb766qsKl58dPnxY6N+/v+Dt7S1YWVkJ3t7ewvDhw4WrV69W2Mbjl2j99ttvQufOnQUbGxvB0dFR6NevnxAfH6/Rpnx7j1/eFhkZKQAQbty48dR9Kgial589zdMuP3v//fcFLy8vwcbGRujcubMQHR39xMvGfvjhByEwMFCwsLDQ+Dm7d+8uNG/e/InbfHQ9crlc8PPzE9q1ayeUlJRotJs8ebJgZmYmREdH/+
vPQCQmiSBocbYKERERGRQeIyciIjJiLORERERGjIWciIjIiLGQExERGTEWciIiIiPGQk5ERGTEWMiJiIiMG
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)\n",
"disp.figure_.suptitle(\"Confusion Matrix\")\n",
"print(f\"Confusion matrix:\\n{disp.confusion_matrix}\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the results from evaluating a classifier are stored in the form of a\n",
"[confusion matrix](https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix) and not in terms of `y_true` and\n",
"`y_pred`, one can still build a `sklearn.metrics.classification_report`\n",
"as follows:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification report rebuilt from confusion matrix:\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.99 0.99 88\n",
" 1 0.99 0.97 0.98 91\n",
" 2 0.99 0.99 0.99 86\n",
" 3 0.98 0.87 0.92 91\n",
" 4 0.99 0.96 0.97 92\n",
" 5 0.95 0.97 0.96 91\n",
" 6 0.99 0.99 0.99 91\n",
" 7 0.96 0.99 0.97 89\n",
" 8 0.94 1.00 0.97 88\n",
" 9 0.93 0.98 0.95 92\n",
"\n",
" accuracy 0.97 899\n",
" macro avg 0.97 0.97 0.97 899\n",
"weighted avg 0.97 0.97 0.97 899\n",
"\n",
"\n"
]
}
],
"source": [
"# Expand the confusion matrix back into label lists: the count stored at\n",
"# row `gt`, column `pred` is the number of samples whose true label is `gt`\n",
"# and whose predicted label is `pred`\n",
"y_true = []\n",
"y_pred = []\n",
"cm = disp.confusion_matrix\n",
"\n",
"for gt, row in enumerate(cm):\n",
"    for pred, count in enumerate(row):\n",
"        y_true.extend([gt] * count)\n",
"        y_pred.extend([pred] * count)\n",
"\n",
"print(\n",
"    \"Classification report rebuilt from confusion matrix:\\n\"\n",
"    f\"{metrics.classification_report(y_true, y_pred)}\\n\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}