82 lines
2.3 KiB
Python
82 lines
2.3 KiB
Python
# import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import scipy.spatial
|
|
from skimage import io
|
|
import time
|
|
|
|
def distance(points, Pc):
    """Return the matrix of pairwise Euclidean distances between two point sets.

    Entry ``[i, j]`` is the distance between row ``i`` of *points* and row
    ``j`` of *Pc*. Both arguments are expected to be 2-D arrays
    (samples x features).
    """
    rows_a = points[:, :]
    rows_b = Pc[:, :]
    return scipy.spatial.distance.cdist(rows_a, rows_b)
|
|
|
|
def kmeans(points=(0, 0), K=1, max_iter=3, eps=0.1):
    """Cluster the rows of *points* into *K* groups with Lloyd's k-means.

    Centers are seeded from *K* randomly chosen rows of *points*, drawn with
    NumPy's global random state so ``np.random.seed`` makes runs
    reproducible. The centers are then refined by alternating two steps:
    assign each row to its nearest center, then re-estimate each center.
    Refinement stops when no center moves by more than *eps* or after
    *max_iter* updates.

    Args:
        points: Array-like of shape ``(N, D)``. A 1-D input is treated as a
            single sample, so the default ``(0, 0)`` is one 2-D point at the
            origin.
        K: Number of clusters (must be >= 1).
        max_iter: Maximum number of center updates. The default of 3 matches
            the limit previously hard-coded here.
        eps: Convergence tolerance on the largest per-center displacement.

    Returns:
        ``(centers, labels, points)``:
        - ``centers`` is a float array of shape ``(K, D)``.
        - ``labels`` is an integer array of shape ``(N,)`` giving, for each
          row, the index of its nearest center in ``centers``.
        - ``points`` is the input as an array.

    Raises:
        ValueError: if *points* is empty or not 1-D/2-D, or if ``K < 1``.
    """
    points = np.asarray(points)
    if points.ndim == 1:
        points = points.reshape(1, -1)
    if points.ndim != 2 or points.shape[0] == 0:
        raise ValueError("points must be a non-empty (N, D) array")
    if K < 1:
        raise ValueError("K must be at least 1")

    n_points = points.shape[0]
    # Float copy: avoids integer overflow/truncation (e.g. uint8 pixels) when averaging.
    data = points.astype(float)

    # Seed with distinct rows whenever possible: duplicate seeds only produce
    # empty clusters. Sampling with replacement is needed only when K > N.
    seeds = np.random.choice(n_points, size=K, replace=K > n_points)
    centers = data[seeds]
    labels = np.argmin(scipy.spatial.distance.cdist(data, centers), axis=1)

    for _ in range(max_iter):
        new_centers = centers.copy()
        for k in range(K):
            members = data[labels == k]
            # An empty cluster keeps its previous center instead of becoming NaN.
            if len(members):
                new_centers[k] = members.mean(axis=0)
        shift = np.linalg.norm(new_centers - centers, axis=1).max()
        centers = new_centers
        # Re-assign after every update so the returned labels match the returned centers.
        labels = np.argmin(scipy.spatial.distance.cdist(data, centers), axis=1)
        if shift <= eps:
            break

    return centers, labels, points
|
|
|
|
def mat_2_img(mat, my_img):
    """Fold a (pixels x channels) matrix back into the 3-D shape of *my_img*.

    Only the first three dimensions of ``my_img.shape`` are used, so *mat*
    must hold exactly ``rows * cols * channels`` elements.
    """
    rows, cols, channels = my_img.shape[0], my_img.shape[1], my_img.shape[2]
    return mat.reshape((rows, cols, channels))
|
|
|
|
def img_2_mat(my_img):
    """Flatten an (H, W, C) image into an (H*W, C) matrix, one pixel per row.

    Pixels are listed in row-major order. The result is a view of *my_img*
    whenever NumPy can reshape without copying.
    """
    height, width, channels = my_img.shape[0], my_img.shape[1], my_img.shape[2]
    n_pixels = height * width
    return my_img.reshape(n_pixels, channels)
|
|
|
|
def kmeans_image(path_image, K):
    """Segment an image by k-means colour quantisation and save the result.

    Every pixel is replaced by the floored centroid of the cluster that
    ``kmeans`` assigns it to. The result is written next to the input as
    ``<path without extension>_<K>.jpg``.

    Args:
        path_image: Path of the image file, read with ``skimage.io.imread``.
        K: Number of colour clusters.

    Returns:
        ``(Pc, index, img_seg)``:
        - ``Pc``: the cluster centroids.
        - ``index``: per-pixel cluster labels, in row-major pixel order.
        - ``img_seg``: the segmented image, with the input's dtype.
    """
    import os

    my_img = io.imread(path_image)
    Mat = img_2_mat(my_img)

    Pc, index, _ = kmeans(Mat, K)

    # Vectorised centroid lookup replaces the per-pixel Python loop. It builds
    # a new array instead of writing through Mat, which may be a view of
    # my_img. The explicit int cast guards against float labels.
    labels = np.asarray(index).astype(int)
    quantised = np.floor(Pc[labels, :]).astype(my_img.dtype)
    img_seg = mat_2_img(quantised, my_img)

    # splitext strips only the final extension, so paths such as
    # "./data/fruits.v2.jpg" keep their directory and full file stem.
    stem, _ = os.path.splitext(path_image)
    io.imsave(stem + "_%d.jpg" % K, img_seg)

    return Pc, index, img_seg
|
|
|
|
if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse kmeans) does not
    # read fruits.jpg and run the demo as a side effect.
    path_image = "fruits.jpg"

    # perf_counter is the high-resolution monotonic clock meant for timing.
    start_time = time.perf_counter()
    Pc, index, img_seg = kmeans_image(path_image=path_image, K=2)
    end_time = time.perf_counter()

    print(f"It took {end_time-start_time:.2f} seconds to compute")
|