# M2_SETI/D3/TP/TP_SETI_Kmeans/Kmeans_cuda.py
#
# K-means colour clustering of an image on the GPU with PyCUDA.
# Listing metadata: 83 lines, 2.3 KiB, Python; last changed 2022-12-22 19:38:17 +01:00.
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
import time
# Load the image and convert it to a NumPy array.
from PIL import Image
# The context manager releases the file handle once the pixels are copied out.
# Forcing RGB guarantees a (height, width, 3) array. A grayscale or palette
# image would otherwise give a 2-D array, and im_data.shape[-1] (used below as
# the channel count) would silently be the image width.
with Image.open('fruits.jpg') as im:
    im_data = np.array(im.convert('RGB'))
# Convert the image data to float32 and normalize it to [0, 1].
im_data = im_data.astype(np.float32) / 255
# CUDA kernel for the K-means assignment step. Each thread handles one data
# point: it computes the squared Euclidean distance to each of the k centroids
# and writes the index of the nearest one to labels[tid].
kernel = """
__global__ void kmeans(float *data, int *labels, float *centroids, int n, int k, int dim)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    // Threads past the last data point (the final block may be partial) do nothing.
    if (tid >= n)
        return;
    // Start from FLT_MAX rather than an arbitrary constant so any finite
    // distance is accepted. Default to centroid 0 so the written label is
    // always a valid index, even if every distance turns out to be NaN.
    float min_dist = 3.402823466e+38f;  /* FLT_MAX */
    int min_centroid = 0;
    for (int i = 0; i < k; i++)
    {
        float dist = 0.0f;
        for (int j = 0; j < dim; j++)
        {
            float diff = data[tid * dim + j] - centroids[i * dim + j];
            dist += diff * diff;
        }
        if (dist < min_dist)
        {
            min_dist = dist;
            min_centroid = i;
        }
    }
    labels[tid] = min_centroid;
}
"""
mod = SourceModule(kernel)
kmeans = mod.get_function("kmeans")
# K-means hyper-parameters: cluster count and number of assignment/update passes.
k, n_iter = 2, 5
# Initial centroids: one row per cluster, one column per channel, drawn
# uniformly from [0, 1) (the same range as the normalized pixels).
centroids = np.random.random((k, im_data.shape[-1])).astype(np.float32)
# One int32 label per pixel; the kernel receives this array as its output buffer.
labels = np.zeros(shape=im_data.shape[:2], dtype=np.int32)
def replace_with_nearest_centroid(centroids, colors):
    """Overwrite every colour in *colors* with its nearest centroid, in place.

    Args:
        centroids: array of shape (k, dim) holding the cluster centres.
        colors: array whose last axis has length ``dim``, e.g. an image of
            shape (height, width, dim) or a flat (n, dim) list of colours.
            It is modified in place.

    Returns:
        None; the result is written into *colors*.
    """
    dim = colors.shape[-1]
    # Flatten all leading axes so each row is one colour. Subtracting a
    # (k, dim) array from an unflattened (h, w, dim) image only broadcasts
    # when w == k. Here (n, 1, dim) - (k, dim) -> (n, k, dim) always works.
    flat = colors.reshape(-1, dim)
    # Squared distances suffice: sqrt is monotonic, so argmin is unaffected.
    distances = np.sum((flat[:, np.newaxis, :] - centroids) ** 2, axis=2)
    # Index of the nearest centroid for each colour, shape (n,).
    nearest_centroids = np.argmin(distances, axis=1)
    # Write back through the caller's array, restoring its original shape.
    colors[...] = centroids[nearest_centroids].reshape(colors.shape)
start_time = time.perf_counter()
# Run the K-means algorithm. Each iteration assigns every pixel to its
# nearest centroid on the GPU, then recomputes each centroid on the CPU as the
# mean colour of its pixels.
n_pixels = im_data.shape[0] * im_data.shape[1]
dim = im_data.shape[-1]
block_size = 1024
# Ceiling division: just enough blocks to cover every pixel (the kernel also
# ignores threads with tid >= n).
grid_size = (n_pixels + block_size - 1) // block_size
# The pixel data does not change during clustering, so upload it once rather
# than re-copying it through cuda.In() on every iteration.
im_data_gpu = cuda.to_device(im_data)
for _ in range(n_iter):
    kmeans(
        im_data_gpu,
        cuda.Out(labels),
        cuda.In(centroids),
        np.int32(n_pixels),
        np.int32(k),
        np.int32(dim),
        block=(block_size, 1, 1),
        grid=(grid_size, 1),
    )
    # Update the centroids. A cluster that received no pixels keeps its
    # previous centroid: np.mean over an empty selection yields NaN, which
    # would poison that centroid for every later iteration.
    for i in range(k):
        members = im_data[labels == i]
        if members.size:
            centroids[i] = members.mean(axis=0)
# Recolour every pixel (in place) with its nearest final centroid.
replace_with_nearest_centroid(centroids=centroids, colors=im_data)
end_time = time.perf_counter()
print(f"It took {end_time-start_time:.2f} seconds to compute")