82 lines
2.3 KiB
Python
82 lines
2.3 KiB
Python
import numpy as np
|
|
import pycuda.autoinit
|
|
import pycuda.driver as cuda
|
|
from pycuda.compiler import SourceModule
|
|
import time
|
|
|
|
# Load the image and convert it to a NumPy array
|
|
from PIL import Image
|
|
# Decode the image as 8-bit RGB so the array is always (height, width, 3);
# a grayscale or CMYK JPEG would otherwise produce a different layout than
# the clustering code expects.  The context manager releases the file handle
# as soon as the pixels have been copied into NumPy.
with Image.open('fruits.jpg') as im:
    im_data = np.array(im.convert('RGB'))

# Convert the image data to float32 and normalize it to the range [0, 1]
im_data = im_data.astype(np.float32) / 255
|
|
|
|
# CUDA kernel for the K-means assignment step.  Each thread handles one point:
# it computes the squared Euclidean distance from point `tid` to each of the
# `k` centroids and stores the index of the closest one in labels[tid].
# `data` holds n * dim floats and `centroids` holds k * dim floats, both with
# the `dim` coordinates of one point stored contiguously.
kernel = """
__global__ void kmeans(float *data, int *labels, float *centroids, int n, int k, int dim)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= n)
        return;

    // Start from FLT_MAX rather than an arbitrary threshold, so that points
    // far away from every centroid are still assigned to the closest one.
    float min_dist = 3.402823466e+38f;
    int min_centroid = -1;
    for (int i = 0; i < k; i++)
    {
        float dist = 0.0f;
        for (int j = 0; j < dim; j++)
        {
            float diff = data[tid * dim + j] - centroids[i * dim + j];
            dist += diff * diff;
        }
        if (dist < min_dist)
        {
            min_dist = dist;
            min_centroid = i;
        }
    }
    labels[tid] = min_centroid;
}
"""

# Compile the kernel source and fetch a callable handle to the kernel
mod = SourceModule(kernel)
kmeans = mod.get_function("kmeans")
|
|
|
|
# Clustering configuration: how many clusters to fit and how many
# assignment/update rounds to run
k, n_iter = 2, 5

# Starting state: a zero label for every pixel, and one random centroid per
# cluster with a coordinate in [0, 1) for each channel
labels = np.zeros(shape=im_data.shape[:2], dtype=np.int32)
centroids = np.random.rand(k, im_data.shape[-1]).astype(np.float32)
|
|
|
|
def replace_with_nearest_centroid(centroids, colors):
    """Overwrite every color in *colors* with its nearest centroid, in place.

    Args:
        centroids: NumPy array of shape (k, channels), one row per centroid.
        colors: Writable NumPy array of shape (..., channels), for example an
            image of shape (height, width, channels) or a flat
            (n, channels) list of colors.  It is modified in place.

    Returns:
        None; the result is written into *colors*.
    """
    # Insert a centroid axis so every color is compared with every centroid:
    # (..., 1, channels) - (k, channels) -> (..., k, channels).
    diffs = colors[..., np.newaxis, :] - centroids

    # Squared distances rank exactly like Euclidean distances, so the square
    # root is not needed to find the nearest centroid.
    sq_distances = np.sum(diffs ** 2, axis=-1)

    # Index of the closest centroid for each color, shape (...)
    nearest_centroids = np.argmin(sq_distances, axis=-1)

    # Replace each color with its nearest centroid
    colors[...] = centroids[nearest_centroids]
|
|
|
|
|
|
# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed intervals
start_time = time.perf_counter()

# Launch geometry: one thread per pixel, rounded up to whole blocks
n_pixels = im_data.shape[0] * im_data.shape[1]
n_channels = im_data.shape[-1]
threads_per_block = 1024
n_blocks = max(1, (n_pixels + threads_per_block - 1) // threads_per_block)

# Run the K-means algorithm
for _ in range(n_iter):
    # Assignment step on the GPU: fills `labels` with the index of the
    # nearest centroid for every pixel
    kmeans(cuda.In(im_data), cuda.Out(labels), cuda.In(centroids),
           np.int32(n_pixels), np.int32(k), np.int32(n_channels),
           block=(threads_per_block, 1, 1), grid=(n_blocks, 1))

    # Update the centroids
    for i in range(k):
        members = im_data[labels == i]
        # An empty cluster keeps its previous centroid; averaging zero rows
        # would yield NaN and poison every later distance computation.
        if members.size:
            centroids[i] = np.mean(members, axis=0)

# Recolor the image in place with the final centroids
replace_with_nearest_centroid(centroids=centroids, colors=im_data)

end_time = time.perf_counter()
print(f"It took {end_time-start_time:.2f} seconds to compute")
|
|
|
|
|