un llyod max qui marche
This commit is contained in:
parent
8cae2bbf6d
commit
bc09d95645
1 changed files with 44 additions and 49 deletions
|
@ -1,58 +1,53 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sipy import integrate
|
from scipy import integrate
|
||||||
from scipy import norm
|
from scipy.stats import norm
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
M = 8
|
|
||||||
X = np.random.normal(0,1,1000)
|
|
||||||
|
|
||||||
def ddp(x):
|
def ddp(x):
|
||||||
mean = 0,
|
mean = 0,
|
||||||
sigma = 1
|
sigma = 1
|
||||||
return norm.pdf(x,mean,sigma)
|
return norm.pdf(x,mean,sigma)
|
||||||
|
|
||||||
def init_thres_vec(M,X):
|
def quant(centroids, X):
|
||||||
step = (np.max(X)-np.min(X))/M
|
bornes = (centroids[:-1]+centroids[1:])/2
|
||||||
thres_intervals = np.array([])
|
bornes = np.insert(bornes,0,-np.inf)
|
||||||
mid = np.mean(X)
|
bornes = np.append(bornes,np.inf)
|
||||||
for i in range(int(M/2)):
|
xquant =np.zeros(len(X))
|
||||||
thres_intervals = np.append(thres_vec,mid+(i+1)*step)
|
for k in range(len(X)):
|
||||||
thres_intervals = np.insert(thtres_vec,0,mid-(1+1)*step)
|
for i in range(len(bornes)):
|
||||||
return thres_intervals
|
if X[k]>=bornes[i] and X[k] <bornes[i+1]:
|
||||||
|
xquant[k] = centroids[i]
|
||||||
def quant(x,thres,intervals):
|
return xquant
|
||||||
thres= np.append(thres, np.inf)
|
def llyodMax(X,M,maxiter=1000,eps=1e-6):
|
||||||
thres= np.insert(thres, 0, -np.inf)
|
#répartition uniforme des bornes
|
||||||
x_hat_q = np.zeros(np.shape(x))
|
step = (np.max(X)-np.min(X))/(M-2)
|
||||||
for i in range(len(thres)-1):
|
Xmin = np.min(X)
|
||||||
if i == 0:
|
bornes = np.array([i*step+Xmin for i in range(M-1)])
|
||||||
x_hat_q = np.where(np.logical_and(x > thres[i], x <= thres[i+1]),
|
bornes = np.insert(bornes,0,-np.inf)
|
||||||
np.full(np.size(x_hat_q), intervals[i]), x_hat_q)
|
bornes = np.append(bornes,np.inf)
|
||||||
elif i == range(len(thres))[-1]-1:
|
centroids = np.zeros(M)
|
||||||
x_hat_q = np.where(np.logical_and(x > thres[i], x <= thres[i+1]),
|
for it in range(maxiter):
|
||||||
np.full(np.size(x_hat_q), intervals[i]), x_hat_q)
|
old_centroids = centroids.copy()
|
||||||
else:
|
for i in range(M):
|
||||||
x_hat_q = np.where(np.logical_and(x > thres[i], x < thres[i+1]),
|
centroids[i] = integrate.quad(lambda x: x*ddp(x),bornes[i],bornes[i+1])[0]\
|
||||||
np.full(np.size(x_hat_q), intervals[i]), x_hat_q)
|
/integrate.quad(lambda x: ddp(x),bornes[i],bornes[i+1])[0]
|
||||||
return x_hat_q
|
bornes[1:-1] = (centroids[:-1]+centroids[1:])/2
|
||||||
|
err = np.linalg.norm(centroids-old_centroids)
|
||||||
|
print(err)
|
||||||
def LlyodMax(X,intervals, max_iter=1000,eps=1e-5):
|
if err < eps :
|
||||||
err_min = np.inf
|
|
||||||
for i in range(max_iter):
|
|
||||||
for j in range(len(x_hat_q)):
|
|
||||||
centroids[i] = integrate.quad(lambda x : x*ddp(x),
|
|
||||||
intervals[j],intervals[j+1])[0]/
|
|
||||||
integrate.quad(lambda x : ddp(x),
|
|
||||||
intervals[j],intervals[j+1])[0]
|
|
||||||
intervals = 0.5*(centroids[1:]+centroids[:-1])
|
|
||||||
x_hat = quant(X,centroids,intervals)
|
|
||||||
err = np.linalg.norm(X-x_hat)
|
|
||||||
if err < err_min:
|
|
||||||
err_min =err
|
|
||||||
intervals_min = intervals
|
|
||||||
centroids_min = centroids
|
|
||||||
if err_min< 1e-5:
|
|
||||||
break
|
break
|
||||||
best_x_hat = quant(X,centroids_min,intervals_min)
|
return centroids
|
||||||
return best_x_hat
|
|
||||||
|
M = 4
|
||||||
|
X = np.random.normal(0,1,1000)
|
||||||
|
centroids = llyodMax(X,M)
|
||||||
|
bornes = (centroids[:-1]+centroids[1:])/2
|
||||||
|
bornes = np.insert(bornes,0,-np.inf)
|
||||||
|
bornes = np.append(bornes,np.inf)
|
||||||
|
|
||||||
|
print(centroids, bornes)
|
||||||
|
plt.figure()
|
||||||
|
plt.plot(X)
|
||||||
|
plt.plot(quant(bornes,X))
|
||||||
|
plt.show()
|
Loading…
Reference in a new issue