better init

This commit is contained in:
Pierre-antoine Comby 2019-05-05 15:44:16 +02:00
parent 9ed86888cf
commit 0f84d76db2

30
455-Codage_Sources/algo_code/LBG.py Normal file → Executable file
View file

@ -1,36 +1,31 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""
Created on Sun May 5 13:59:37 2019
@author: pac
Algorithme de Linde-Buzo-Gray, version 2D
"""
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.spatial import Voronoi, voronoi_plot_2d from scipy.spatial import Voronoi, voronoi_plot_2d
#
mean= [0,0]
cov = [[1,0],[0,1]]
M = 20; M = 20;
N =100; #point par cluster N =100; #point par cluster
K = N*M K = N*M
means = np.random.rand(M,2)*10 means = np.random.rand(M,2)*10
X = np.zeros((K,2)) X = np.zeros((K,2))
plt.figure() plt.figure()
cov = np.array([[1,0],[0,1]])
for m in range(M): for m in range(M):
xi = np.random.multivariate_normal(means[m,:],cov,N) xi = np.random.multivariate_normal(means[m,:],cov,N)
X[m*N:(m+1)*N] = xi X[m*N:(m+1)*N] = xi
plt.plot(xi[:,0],xi[:,1],'+') plt.plot(xi[:,0],xi[:,1],'+')
plt.plot(means[:,0],means[:,1],'ob') plt.plot(means[:,0],means[:,1],'ob')
plt.show()
# X = np.random.multivariate_normal(mean,cov,K) mean= np.mean(X,axis=0)
Y0 = np.random.multivariate_normal(mean, cov,M)
Y0 = means; Y0 = np.random.multivariate_normal(mean, 10*cov, M)
plt.show()
print(Y0)
Y0= means
plt.plot(Y0[:,0],Y0[:,1],'ok')
def LBG(X,Y0,eps=1e-5,maxiter=1000): def LBG(X,Y0,eps=1e-5,maxiter=1000):
Y = Y0.copy() Y = Y0.copy()
old_dist = np.inf old_dist = np.inf
@ -46,6 +41,7 @@ def LBG(X,Y0,eps=1e-5,maxiter=1000):
dist += sum((X[k]-quant_min)**2) dist += sum((X[k]-quant_min)**2)
for j in range(len(Y)): for j in range(len(Y)):
Y[j,:] = np.mean(X[cluster_index==j],axis=0) Y[j,:] = np.mean(X[cluster_index==j],axis=0)
print(Y)
if dist-old_dist < eps: if dist-old_dist < eps:
break break
else: else:
@ -56,6 +52,6 @@ vor = Voronoi(Y)
voronoi_plot_2d(vor,show_vertices=False) voronoi_plot_2d(vor,show_vertices=False)
print(Y) print(Y)
plt.plot(X[:,0],X[:,1],'+') plt.plot(X[:,0],X[:,1],'+')
plt.plot(Y[:,0],Y[:,1],'o') plt.plot(Y[:,0],Y[:,1],'ob')
#plt.plot(Y0[:,0],Y0[:,1],'ob') plt.plot(Y0[:,0],Y0[:,1],'ok')
plt.show() plt.show()