import time

import numpy as np
import pycuda.autoinit  # noqa: F401  imported only for its initialization side effect
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
from PIL import Image

# Load the image and convert it to a NumPy array.  The context manager closes
# the underlying file once the pixel data has been copied out.
with Image.open('fruits.jpg') as im:
    im_data = np.array(im)

# The kernel and the centroid update treat the LAST axis as the per-pixel
# feature (channel) axis.  A single-channel image comes back 2-D, where
# shape[-1] would be the image width, so give it an explicit channel axis.
if im_data.ndim == 2:
    im_data = im_data[..., np.newaxis]

# Convert the image data to float32 and normalize it (assumes 8-bit channels).
# The kernel indexes the buffer as a flat row-major array, so force C order.
im_data = np.ascontiguousarray(im_data.astype(np.float32) / 255)

# CUDA kernel for the K-means assignment step: each thread handles one pixel
# and writes the index of the nearest centroid (squared Euclidean distance).
kernel = """
__global__ void kmeans(float *data, int *labels, float *centroids, int n, int k, int dim)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= n) return;

    // Seed the running minimum from the first centroid rather than a magic
    // sentinel value, so the result does not depend on the data's scale.
    int min_centroid = -1;
    float min_dist = 0.0f;
    for (int i = 0; i < k; i++) {
        float dist = 0.0f;
        for (int j = 0; j < dim; j++) {
            float diff = data[tid * dim + j] - centroids[i * dim + j];
            dist += diff * diff;
        }
        if (min_centroid < 0 || dist < min_dist) {
            min_dist = dist;
            min_centroid = i;
        }
    }
    labels[tid] = min_centroid;
}
"""
mod = SourceModule(kernel)
kmeans = mod.get_function("kmeans")

# Set the number of clusters and the number of iterations
k = 2
n_iter = 5
# Seed for centroid initialization; set to an int for reproducible runs.
seed = None

n_pixels = im_data.shape[0] * im_data.shape[1]
dim = im_data.shape[-1]
threads_per_block = 1024
# Ceiling division: exactly enough blocks to cover every pixel (the kernel's
# `tid >= n` guard handles the partially filled last block).
n_blocks = (n_pixels + threads_per_block - 1) // threads_per_block

# Initialize the centroids from k distinct pixels of the image (Forgy
# initialization).  Uniform random points in the unit cube can land far from
# every pixel and produce clusters that never receive any members.
rng = np.random.default_rng(seed)
pixels = im_data.reshape(-1, dim)
centroids = pixels[rng.choice(n_pixels, size=k, replace=False)].astype(np.float32)
labels = np.zeros(im_data.shape[:2], dtype=np.int32)


def replace_with_nearest_centroid(centroids, colors):
    """Overwrite every color vector in ``colors`` with its nearest centroid, in place.

    ``colors`` may have any leading shape (e.g. an (H, W, C) image or an (N, C)
    list of colors); its last axis must match the last axis of ``centroids``
    (shape (k, C)).  Distances are Euclidean; ties go to the lower centroid index.
    """
    # Insert a centroid axis so every color is compared with every centroid:
    # (..., 1, C) - (k, C) -> (..., k, C).
    diffs = colors[..., np.newaxis, :] - centroids
    # Squared distances suffice: sqrt is monotonic and does not change argmin.
    sq_dists = np.sum(diffs ** 2, axis=-1)  # (..., k)
    nearest_centroids = np.argmin(sq_dists, axis=-1)  # (...)
    colors[...] = centroids[nearest_centroids]


start_time = time.perf_counter()

# The image never changes between iterations, so copy it to the device once
# instead of re-uploading it on every kernel launch.
data_gpu = cuda.to_device(im_data)

# Run the K-means algorithm
for _ in range(n_iter):
    # Assignment step (GPU): label every pixel with its nearest centroid.
    kmeans(data_gpu, cuda.Out(labels), cuda.In(centroids),
           np.int32(n_pixels), np.int32(k), np.int32(dim),
           block=(threads_per_block, 1, 1), grid=(n_blocks, 1))

    # Update step (CPU): move each centroid to the mean of its members.
    for i in range(k):
        members = im_data[labels == i]
        # The mean of an empty cluster is NaN, and a NaN centroid would then be
        # picked by argmin for every pixel; keep the previous centroid instead.
        if len(members):
            centroids[i] = members.mean(axis=0)

data_gpu.free()

replace_with_nearest_centroid(centroids=centroids, colors=im_data)

end_time = time.perf_counter()
print(f"It took {end_time-start_time:.2f} seconds to compute")