You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

289 lines
25 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TP1 KMEANS\n",
"\n",
"On nous propose de coder l'algorithme des k-means afin de faire du clustering sur 2 classes, puis sur plus de 2 classes.\n",
"Plus tard, on utilisera notre algorithme pour segmenter une image à partir de l'information de couleur."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import scipy.spatial"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Experiment configuration: Gaussian clusters of points.\n",
"# Alternative 4-cluster setup: mean = [1, 2, 3, 4], sd = [0.25, 0.25, 0.1, 0.2], clusters = 4\n",
"SEED = 0  # fixed seed so that Restart & Run All reproduces the same points\n",
"np.random.seed(SEED)\n",
"\n",
"mean = [1, 2]        # per-cluster mean (gen_points uses it on every axis)\n",
"sd = [0.25, 0.25]    # per-cluster standard deviation\n",
"dim = 2              # dimension of each point\n",
"nb = 10              # number of points generated per cluster\n",
"clusters = 2         # number of generated clusters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fonctions à utiliser pour le clustering"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def gen_points(mean=1, sd=0.5, nb=100, dim=2, clusters=2):\n",
"    \"\"\"Generate `clusters` Gaussian blobs of `nb` points each.\n",
"\n",
"    Each coordinate of a point of cluster i is drawn from N(mean[i], sd[i]),\n",
"    i.e. the same mean and standard deviation are used on every axis.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    mean, sd : scalar or sequence\n",
"        Per-cluster mean / standard deviation. A scalar (or 1-element\n",
"        sequence) is shared by all clusters; otherwise at least `clusters`\n",
"        entries are required and the first `clusters` are used.\n",
"    nb : int\n",
"        Number of points per cluster.\n",
"    dim : int\n",
"        Dimension of each point.\n",
"    clusters : int\n",
"        Number of clusters.\n",
"\n",
"    Returns\n",
"    -------\n",
"    np.ndarray of shape (nb * clusters, dim), the clusters stacked in order.\n",
"    An empty (0, dim) array is returned when clusters < 1.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If fewer than `clusters` means or standard deviations are given.\n",
"    \"\"\"\n",
"    if clusters < 1:\n",
"        return np.empty((0, dim))\n",
"    means = np.atleast_1d(mean)\n",
"    sds = np.atleast_1d(sd)\n",
"    # Broadcast scalars to every cluster (the scalar defaults used to crash on mean[0]).\n",
"    if means.size == 1:\n",
"        means = np.repeat(means, clusters)\n",
"    if sds.size == 1:\n",
"        sds = np.repeat(sds, clusters)\n",
"    if min(means.size, sds.size) < clusters:\n",
"        raise ValueError(f\"need {clusters} means and sds, got {means.size} and {sds.size}\")\n",
"    # Draw the clusters in order, so the random stream matches a sequential\n",
"    # loop over clusters.\n",
"    blobs = [np.random.normal(means[i], sds[i], size=(nb, dim)) for i in range(clusters)]\n",
"    return np.concatenate(blobs, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def distance(points, Pc=None):\n",
"    \"\"\"Euclidean distances between every point and every prototype.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    points : array-like of shape (n, dim)\n",
"    Pc : array-like of shape (K, dim) or (dim,), optional\n",
"        Prototypes. A single prototype may be given as a flat sequence.\n",
"        If None, the pairwise distances between `points` are returned.\n",
"\n",
"    Returns\n",
"    -------\n",
"    np.ndarray of shape (n, K) where entry [i, k] = ||points[i] - Pc[k]||.\n",
"    \"\"\"\n",
"    points = np.atleast_2d(points)\n",
"    # Compare against the prototypes, not against the points themselves;\n",
"    # cdist needs 2-D inputs, so a flat prototype like [0, 0] is promoted.\n",
"    Pc = points if Pc is None else np.atleast_2d(Pc)\n",
"    return scipy.spatial.distance.cdist(points, Pc)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def kmeans(points=(0, 0), K=1, nb=1, dim=2, max_iter=100):\n",
"    \"\"\"Cluster `points` into K groups with Lloyd's k-means algorithm.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    points : array-like of shape (n, dim)\n",
"    K : int\n",
"        Number of prototypes (must satisfy 1 <= K <= n).\n",
"    nb, dim : int\n",
"        Kept for backward compatibility only; the number of points and their\n",
"        dimension are read from `points` itself.\n",
"    max_iter : int\n",
"        Maximum number of assignment/update rounds. 0 returns the random\n",
"        initial prototypes only.\n",
"\n",
"    Returns\n",
"    -------\n",
"    np.ndarray of shape (K, dim): the final prototypes.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If K is not between 1 and the number of points.\n",
"    \"\"\"\n",
"    points = np.atleast_2d(np.asarray(points, dtype=float))\n",
"    n = points.shape[0]\n",
"    if not 1 <= K <= n:\n",
"        raise ValueError(f\"K must be between 1 and the number of points ({n}), got {K}\")\n",
"    # Initialisation: K *distinct* points drawn at random among the n points\n",
"    # (drawing from [0, nb*dim) was wrong and allowed duplicate prototypes).\n",
"    Pc = points[np.random.choice(n, size=K, replace=False)]\n",
"    for _ in range(max_iter):\n",
"        # Assignment step: index of the nearest prototype for every point.\n",
"        labels = scipy.spatial.distance.cdist(points, Pc).argmin(axis=1)\n",
"        # Update step: move each prototype to the mean of its points;\n",
"        # a prototype that lost all its points stays where it is.\n",
"        new_Pc = Pc.copy()\n",
"        for k in range(K):\n",
"            members = points[labels == k]\n",
"            if len(members):\n",
"                new_Pc[k] = members.mean(axis=0)\n",
"        converged = np.allclose(new_Pc, Pc)\n",
"        Pc = new_Pc\n",
"        if converged:\n",
"            break\n",
"    return Pc"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def visualisation(points, Pc=None, dim=2):\n",
"    \"\"\"Plot 2-D `points` as dots and the prototypes `Pc` as red crosses.\n",
"\n",
"    Nothing is drawn when `dim` != 2. `Pc` may be None (points only), a single\n",
"    prototype given as a flat sequence, or an array of shape (K, 2). The square\n",
"    plotting window is derived from the plotted data rather than from the\n",
"    notebook's global `mean`, so the function no longer depends on hidden state.\n",
"    \"\"\"\n",
"    if dim != 2:\n",
"        return\n",
"    points = np.atleast_2d(points)\n",
"    plt.plot(points[:, 0], points[:, 1], 'o', label='points')\n",
"    shown = points\n",
"    if Pc is not None:\n",
"        Pc = np.atleast_2d(Pc)\n",
"        plt.plot(Pc[:, 0], Pc[:, 1], 'r+', label='prototypes')\n",
"        shown = np.vstack([points, Pc])\n",
"    plt.grid(True)\n",
"    plt.xlabel('x[0]')\n",
"    plt.ylabel('x[1]')\n",
"    plt.legend()\n",
"    if shown.size:\n",
"        # Same limits on both axes, with a 0.5 margin around all plotted data.\n",
"        lo, hi = shown.min() - 0.5, shown.max() + 0.5\n",
"        plt.axis([lo, hi, lo, hi])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(20, 2)\n"
]
}
],
"source": [
"points = gen_points(mean, sd, nb=nb, dim=dim, clusters=clusters)\n",
"assert points.shape == (nb * clusters, dim)\n",
"points.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0. 0.39108103 0.34583518 0.70546644 0.3203134 0.3516725\n",
" 0.67971143 0.12125982 0.61902803 0.25895127 1.56135251 1.16925868\n",
" 0.89449237 1.43352053 1.0743239 0.93510189 1.41988547 1.58011875\n",
" 1.08411331 1.18671248]\n",
" [0.39108103 0. 0.41265824 0.31753789 0.20134828 0.1520557\n",
" 0.29442779 0.48416069 0.32820193 0.29379395 1.92061143 1.55795192\n",
" 1.2792717 1.80777689 1.45979562 1.32616725 1.80571725 1.97050583\n",
" 1.47031622 1.56670796]\n",
" [0.34583518 0.41265824 0. 0.62859739 0.21220435 0.49592519\n",
" 0.67380419 0.31854927 0.73773146 0.51626889 1.8336252 1.37308473\n",
" 1.01281801 1.49334278 1.29787751 1.11001289 1.52748312 1.75952011\n",
" 1.30415796 1.42577374]\n",
" [0.70546644 0.31753789 0.62859739 0. 0.44061757 0.42670666\n",
" 0.13916838 0.7887133 0.38129538 0.58971012 2.2340252 1.87432417\n",
" 1.58220205 2.09814273 1.77707512 1.63872881 2.10818677 2.28550213\n",
" 1.78747095 1.88420608]\n",
" [0.3203134 0.20134828 0.21220435 0.44061757 0. 0.30494113\n",
" 0.46740546 0.36865442 0.52868835 0.37372177 1.88068426 1.46844245\n",
" 1.14902285 1.65807412 1.3803439 1.21975293 1.67330997 1.8710176\n",
" 1.38894595 1.49825024]\n",
" [0.3516725 0.1520557 0.49592519 0.42670666 0.30494113 0.\n",
" 0.35637593 0.46606861 0.27034798 0.16380212 1.81240225 1.48388867\n",
" 1.24193007 1.78456017 1.37947088 1.26763783 1.76514842 1.90097935\n",
" 1.39123292 1.47795072]\n",
" [0.67971143 0.29442779 0.67380419 0.13916838 0.46740546 0.35637593\n",
" 0. 0.77794127 0.24735327 0.51809049 2.16596566 1.83655551\n",
" 1.57217378 2.1022043 1.7339787 1.61275974 2.09839429 2.25219229\n",
" 1.74541605 1.83405921]\n",
" [0.12125982 0.48416069 0.31854927 0.7887133 0.36865442 0.46606861\n",
" 0.77794127 0. 0.73558536 0.38006228 1.52663824 1.09978826\n",
" 0.79543522 1.32585453 1.01264344 0.85309758 1.32192905 1.50347121\n",
" 1.02093103 1.13291852]\n",
" [0.61902803 0.32820193 0.73773146 0.38129538 0.52868835 0.27034798\n",
" 0.24735327 0.73558536 0. 0.38834 1.99112336 1.71317083\n",
" 1.50027904 2.04749497 1.6029451 1.51251266 2.01963574 2.13226014\n",
" 1.61597543 1.69022564]\n",
" [0.25895127 0.29379395 0.51626889 0.58971012 0.37372177 0.16380212\n",
" 0.51809049 0.38006228 0.38834 0. 1.64887732 1.33143897\n",
" 1.11366524 1.66342238 1.22437768 1.12471467 1.63153555 1.74982768\n",
" 1.23666771 1.31925162]\n",
" [1.56135251 1.92061143 1.8336252 2.2340252 1.88068426 1.81240225\n",
" 2.16596566 1.52663824 1.99112336 1.64887732 0. 0.58770443\n",
" 1.02271415 1.15136496 0.58808445 0.83705754 0.94499234 0.6037414\n",
" 0.59102406 0.44614399]\n",
" [1.16925868 1.55795192 1.37308473 1.87432417 1.46844245 1.48388867\n",
" 1.83655551 1.09978826 1.71317083 1.33143897 0.58770443 0.\n",
" 0.44427897 0.64604338 0.12332713 0.27518253 0.47328172 0.41918021\n",
" 0.10638777 0.1741561 ]\n",
" [0.89449237 1.2792717 1.01281801 1.58220205 1.14902285 1.24193007\n",
" 1.57217378 0.79543522 1.50027904 1.11366524 1.02271415 0.44427897\n",
" 0. 0.55487036 0.44065169 0.19125029 0.52649414 0.76267241\n",
" 0.43456127 0.57921824]\n",
" [1.43352053 1.80777689 1.49334278 2.09814273 1.65807412 1.78456017\n",
" 2.1022043 1.32585453 2.04749497 1.66342238 1.15136496 0.64604338\n",
" 0.55487036 0. 0.7423045 0.62079242 0.20929273 0.61317826\n",
" 0.7258323 0.81728057]\n",
" [1.0743239 1.45979562 1.29787751 1.77707512 1.3803439 1.37947088\n",
" 1.7339787 1.01264344 1.6029451 1.22437768 0.58808445 0.12332713\n",
" 0.44065169 0.7423045 0. 0.25088095 0.58359133 0.5339439\n",
" 0.01763323 0.14206747]\n",
" [0.93510189 1.32616725 1.11001289 1.63872881 1.21975293 1.26763783\n",
" 1.61275974 0.85309758 1.51251266 1.12471467 0.83705754 0.27518253\n",
" 0.19125029 0.62079242 0.25088095 0. 0.52473292 0.65180514\n",
" 0.24624553 0.3914113 ]\n",
" [1.41988547 1.80571725 1.52748312 2.10818677 1.67330997 1.76514842\n",
" 2.09839429 1.32192905 2.01963574 1.63153555 0.94499234 0.47328172\n",
" 0.52649414 0.20929273 0.58359133 0.52473292 0. 0.40874924\n",
" 0.56617043 0.63624877]\n",
" [1.58011875 1.97050583 1.75952011 2.28550213 1.8710176 1.90097935\n",
" 2.25219229 1.50347121 2.13226014 1.74982768 0.6037414 0.41918021\n",
" 0.76267241 0.61317826 0.5339439 0.65180514 0.40874924 0.\n",
" 0.51939313 0.48569502]\n",
" [1.08411331 1.47031622 1.30415796 1.78747095 1.38894595 1.39123292\n",
" 1.74541605 1.02093103 1.61597543 1.23666771 0.59102406 0.10638777\n",
" 0.43456127 0.7258323 0.01763323 0.24624553 0.56617043 0.51939313\n",
" 0. 0.14516702]\n",
" [1.18671248 1.56670796 1.42577374 1.88420608 1.49825024 1.47795072\n",
" 1.83405921 1.13291852 1.69022564 1.31925162 0.44614399 0.1741561\n",
" 0.57921824 0.81728057 0.14206747 0.3914113 0.63624877 0.48569502\n",
" 0.14516702 0. ]]\n"
]
}
],
"source": [
"# Sanity check of distance(): show only the first rows instead of dumping the whole matrix.\n",
"dist = distance(points, Pc=[0, 0])\n",
"dist[:5]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1.2631266 0.8529462 ]\n",
" [1.1325475 1.17318217]]\n"
]
}
],
"source": [
"# Ask for as many prototypes as there are generated clusters (see the config cell).\n",
"Pc = kmeans(points, K=clusters, nb=nb, dim=dim)\n",
"Pc"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Dots: generated points; red crosses: prototypes returned by kmeans.\n",
"visualisation(points, Pc=Pc, dim=dim)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.10 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}