You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

83 lines
2.3 KiB
Python

# import matplotlib.pyplot as plt
import os
import time

import numpy as np
import scipy.spatial
from skimage import io
def distance(points, Pc):
    """Return the pairwise distance matrix between two 2-D point sets.

    Entry ``[i, j]`` is the Euclidean distance between row ``i`` of
    ``points`` and row ``j`` of ``Pc`` (``scipy.spatial.distance.cdist``
    with its default metric).
    """
    # Full 2-D slices: both inputs must be 2-D arrays.
    lhs = points[:, :]
    rhs = Pc[:, :]
    return scipy.spatial.distance.cdist(lhs, rhs)
def _nearest_centroid(pts, centroids):
    """Return, for each row of ``pts``, the index of its closest centroid."""
    return np.argmin(scipy.spatial.distance.cdist(pts, centroids), axis=1)


def _update_centroids(pts, labels, centroids):
    """Return the per-cluster mean of ``pts``; empty clusters keep their old centroid."""
    n_clusters, dim = centroids.shape
    counts = np.bincount(labels, minlength=n_clusters)
    # Per-cluster coordinate sums, one bincount per feature column.
    sums = np.stack(
        [np.bincount(labels, weights=pts[:, d], minlength=n_clusters) for d in range(dim)],
        axis=1,
    )
    new_centroids = centroids.copy()
    nonempty = counts > 0
    # Empty clusters would divide by zero (NaN); leave them where they were.
    new_centroids[nonempty] = sums[nonempty] / counts[nonempty, None]
    return new_centroids


def kmeans(points=(0, 0), K=1, *, eps=0.1, max_iter=3, rng=None):
    """Cluster ``points`` into ``K`` groups with Lloyd's k-means algorithm.

    Parameters
    ----------
    points : array_like, shape (N, dim) or (N,)
        Samples to cluster, one per row. A 1-D input is treated as N
        one-dimensional samples.
    K : int
        Number of clusters (must be >= 1).
    eps : float, keyword-only
        Stop early once the mean displacement of the centroids between two
        consecutive updates is <= ``eps``.
    max_iter : int, keyword-only
        Maximum number of centroid updates.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness for choosing the initial centroids. Defaults to
        the global ``numpy.random`` state.

    Returns
    -------
    Pc : ndarray, shape (K, dim), float
        Final centroids.
    indice : ndarray, shape (N,), int
        Index into ``Pc`` of the nearest centroid for every sample.
    points : ndarray, shape (N, dim)
        The input samples as a 2-D array.

    Raises
    ------
    ValueError
        If ``K < 1`` or ``points`` contains no samples.
    """
    pts = np.asarray(points)
    if pts.ndim == 1:
        pts = pts.reshape(-1, 1)
    n_points = pts.shape[0]
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")
    if n_points == 0:
        raise ValueError("points must contain at least one sample")
    if rng is None:
        rng = np.random

    # Seed with K distinct samples; sampling with replacement is only needed
    # when there are fewer samples than clusters.
    init = rng.choice(n_points, size=K, replace=K > n_points)
    centroids = pts[init].astype(float)
    labels = _nearest_centroid(pts, centroids)

    for _ in range(max_iter):
        new_centroids = _update_centroids(pts, labels, centroids)
        # Mean distance each centroid moved from its OWN previous position.
        shift = np.linalg.norm(new_centroids - centroids, axis=1).mean()
        centroids = new_centroids
        # Re-assign so the returned labels match the returned centroids.
        labels = _nearest_centroid(pts, centroids)
        if shift <= eps:
            break

    return centroids, labels, pts
def mat_2_img(mat, my_img):
    """Reshape a (pixels, channels) matrix back into the layout of ``my_img``.

    Inverse of ``img_2_mat``. For a colour image ``(H, W, C)`` the result has
    shape ``(H, W, C)``. For a grayscale image ``(H, W)`` the result has shape
    ``(H, W)``, so an ``(H*W, 1)`` matrix round-trips correctly.
    """
    # At most the first three dimensions, matching the original 3-D behaviour.
    target_shape = my_img.shape[:3]
    return mat.reshape(target_shape)
def img_2_mat(my_img):
    """Flatten an image into a ``(H*W, channels)`` matrix, one pixel per row.

    A colour image ``(H, W, C)`` becomes ``(H*W, C)``. A grayscale image
    ``(H, W)`` is treated as having a single channel and becomes ``(H*W, 1)``.
    """
    height, width = my_img.shape[0], my_img.shape[1]
    channels = my_img.shape[2] if my_img.ndim > 2 else 1
    return my_img.reshape(height * width, channels)
def kmeans_image(path_image, K):
    """Segment the image at ``path_image`` into ``K`` colours with k-means.

    Each pixel is replaced by the floored centroid of its cluster. The result
    is saved next to the input as ``<path without extension>_<K>.jpg``.

    Returns
    -------
    Pc : centroids returned by ``kmeans``.
    index : cluster label of every pixel, in row-major pixel order.
    img_seg : the segmented image, reshaped to the input image's layout and
        cast to its dtype.
    """
    my_img = io.imread(path_image)
    Mat = img_2_mat(my_img)
    Pc, index, _ = kmeans(Mat, K)
    # Labels must be integers to index Pc (they may come back as floats).
    labels = np.asarray(index).astype(np.intp)
    # One vectorised lookup instead of a per-pixel Python loop; cast back to
    # the image dtype, as assigning into Mat did before.
    seg = np.floor(Pc[labels]).astype(Mat.dtype)
    img_seg = mat_2_img(seg, my_img)
    # splitext strips only the extension; split('.') broke on paths such as
    # "./data/fruits.jpg".
    stem = os.path.splitext(path_image)[0]
    io.imsave(stem + "_%d.jpg" % K, img_seg)
    return Pc, index, img_seg
# Demo run: segment "fruits.jpg" into 2 colour clusters and report wall time.
# NOTE(review): this runs on import; consider an `if __name__ == "__main__":`
# guard if the module is ever imported elsewhere.
path_image = "fruits.jpg"
start_time = time.time()
Pc, index, img_seg = kmeans_image(path_image, 2)
end_time = time.time()
elapsed = end_time - start_time
print(f"It took {elapsed:.2f} seconds to compute")