K-Means Clustering Algorithm#
Learning Outcomes#
This example teaches how to implement the k-means clustering algorithm using NumPy; it is based on the k-means example in cuPyNumeric.
In this example, you will learn:

* how to compute pairwise distances using `newaxis`
* how to “bin” data using `bincount`
* how to locate data using `where` and `argmin`
* how to use a boolean mask to select data
* how to compute the `norm` of an n-dimensional array
* how to determine if two NumPy arrays don’t have the same values using `not_equal`
* how to plot clusters using matplotlib
Background#
To learn more about the clustering algorithm, see this material.
Implementation#
[1]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from typing import Tuple
For reproducibility, use a fixed seed to create the default random number generator.
[2]:
# A fixed seed makes every run of this notebook produce the same data,
# the same initial centroids, and therefore the same clusters.
random_seed: int = 42
rng = np.random.default_rng(seed=random_seed)
[3]:
def initialize(
    n_elements: int,
    n_dims: int,
    n_centroids: int,
    generator: "np.random.Generator | None" = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Randomly initialize the observations and the initial cluster centroids.

    Every coordinate of both arrays is drawn uniformly from [0, 1).

    n_elements: int
        Number of elements/observations to be clustered
    n_dims: int
        Dimension of the elements/observations
    n_centroids: int
        Number of clusters
    generator: np.random.Generator, optional
        Random number generator to draw from. When omitted, the module-level
        seeded ``rng`` is used, so existing calls stay reproducible.
    Returns:
        A tuple ``(data, centroids)`` with shapes ``(n_elements, n_dims)``
        and ``(n_centroids, n_dims)``
    """
    if generator is None:
        # Fall back to the notebook-wide seeded generator.
        generator = rng
    data = generator.random((n_elements, n_dims))
    centroids = generator.random((n_centroids, n_dims))
    return data, centroids
[4]:
def calculate_distances(data: np.ndarray, centroids: np.ndarray, data_magnitude_squared: np.ndarray) -> np.ndarray:
    """
    Compute the squared Euclidean distance between every observation and every centroid.

    Uses the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c, so all pairs
    are handled by broadcasting plus a single matrix product.

    data: np.ndarray
        Observations to be clustered, one per row
    centroids: np.ndarray
        Cluster centers, one per row
    data_magnitude_squared: np.ndarray
        Precomputed squared norm of each observation (||x||^2)
    Returns: np.ndarray
        Matrix whose entry [i, j] is the squared distance from data[i] to centroids[j]
    """
    # ||c||^2 for every centroid
    centroid_norms_sq = np.square(np.linalg.norm(centroids, ord=2, axis=1))
    # A column of ||x||^2 broadcast against a row of ||c||^2 gives one entry per (x, c) pair
    result = data_magnitude_squared[:, np.newaxis] + centroid_norms_sq[np.newaxis, :]
    # Subtract the cross term 2 x.c, obtained for all pairs from one dot product
    cross = np.dot(data, centroids.T)
    result -= 2.0 * cross
    return result
[5]:
def relabel(pairwise_distances: np.ndarray) -> np.ndarray:
    """
    Assign every observation to its nearest centroid.

    pairwise_distances: np.ndarray
        Distances with one row per observation and one column per centroid
    Returns: np.ndarray
        For each row, the column index of its smallest entry
        (ties resolve to the lowest index, as with ``np.argmin``)
    """
    nearest = pairwise_distances.argmin(axis=1)
    return nearest
[6]:
def find_centroids(
    centroids: np.ndarray,
    data: np.ndarray,
    labels: np.ndarray,
    pairwise_distances: np.ndarray,
    zero_point: np.ndarray,
    n_centroids: int
) -> float:
    """Update the centroids in place and return the total assigned distance.

    Each centroid is moved to the mean of the observations currently labeled
    with it. A centroid with no observations keeps its previous position
    instead of being reset to the origin.

    centroids: np.ndarray
        The center of the clusters, one per row; overwritten in place
    data: np.ndarray
        Observations that need to be clustered, one per row
    labels: np.ndarray
        Integer cluster index of each observation (expected in [0, n_centroids))
    pairwise_distances: np.ndarray
        Distance between each observation (rows) and each centroid (columns)
    zero_point: np.ndarray
        Row of zeros (run_kmeans builds it with shape (1, n_dims)); it is the
        fill value for observations outside the cluster being summed
    n_centroids: int
        Number of clusters
    Returns:
        Sum, over all observations, of the distance to their assigned centroid
    """
    # Number of observations associated with each centroid
    counts = np.bincount(labels, minlength=n_centroids)
    distance_sum = 0.0
    for idx in range(n_centroids):
        # Boolean mask indicating which observations belong to this centroid
        centroid_mask = labels == idx
        distance_sum += np.sum(
            np.where(centroid_mask, pairwise_distances[:, idx], 0.0)
        )
        if counts[idx] == 0:
            # Empty cluster: there is nothing to average, so leave the centroid
            # where it was (this also avoids a division by zero).
            continue
        # Sum the member observations (non-members contribute zero_point) and
        # divide by the member count to obtain their mean.
        centroids[idx, :] = np.sum(
            np.where(centroid_mask[..., np.newaxis], data, zero_point), axis=0
        ) / counts[idx]
    return distance_sum
[7]:
def run_kmeans(
    n_centroids: int,
    n_dims: int,
    n_iters: int,
    n_elements: int,
    n_iter_check: int,
    tolerance: float = 1e-6,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate observations and cluster them into the requested number of clusters.

    n_centroids: int
        Number of clusters
    n_dims: int
        Dimension of the elements/observations
    n_iters: int
        Maximum number of iterations
    n_elements: int
        Number of elements/observations to be clustered
    n_iter_check: int
        Convergence is tested every n_iter_check iterations
    tolerance: float
        Relative decrease of the total distance below which the run counts as
        converged (default 1e-6)
    Returns:
        A tuple ``(data, labels, centroids)`` where ``centroids`` are the means
        of the observations grouped by ``labels``. ``labels`` is ``None`` when
        ``n_iters`` is less than 1, because no iteration ran.
    """
    print("Running kmeans...")
    print(f"Number of data points: {n_elements}")
    print(f"Number of dimensions: {n_dims}")
    print(f"Number of centroids: {n_centroids}")
    print(f"Max iterations: {n_iters}")
    data, centroids = initialize(n_elements, n_dims, n_centroids)
    # ||x||^2 of every observation never changes, so compute it once
    data_magnitude_squared = np.square(np.linalg.norm(data, ord=2, axis=1))
    zero_point = np.zeros((1, data.shape[1]), dtype=data.dtype)
    labels = None
    prior_distance_sum = None
    # Run for at most n_iters iterations, stopping early once converged.
    # Convergence is only tested every n_iter_check iterations.
    for iteration in range(n_iters):
        pairwise_distances = calculate_distances(data, centroids, data_magnitude_squared)
        new_labels = relabel(pairwise_distances)
        distance_sum = find_centroids(
            centroids,
            data,
            new_labels,
            pairwise_distances,
            zero_point,
            n_centroids,
        )
        converged = False
        if iteration > 0 and iteration % n_iter_check == 0:
            # How many observations switched clusters since the previous iteration
            total_changes = np.sum(np.not_equal(labels, new_labels))
            # Converged if nothing moved, or the total distance barely shrank.
            # Multiplying instead of dividing avoids 0/0 when the prior sum is 0.
            converged = bool(
                total_changes == 0
                or distance_sum > (1 - tolerance) * prior_distance_sum
            )
        prior_distance_sum = distance_sum
        # Update before breaking so the returned labels match the returned centroids
        labels = new_labels
        if converged:
            break
    return data, labels, centroids
Let's run the k-means algorithm using a set of inputs#
[8]:
# Size of the problem: how many clusters, dimensions and observations
n_centroids: int = 5
n_dims: int = 2
n_elements: int = 256
# Iteration budget, and how often convergence is checked
n_iters: int = 100
n_iter_check: int = 10
data, labels, centroids = run_kmeans(
    n_centroids=n_centroids,
    n_dims=n_dims,
    n_iters=n_iters,
    n_elements=n_elements,
    n_iter_check=n_iter_check,
)
Running kmeans...
Number of data points: 256
Number of dimensions: 2
Number of centroids: 5
Max iterations: 100
Generate a color map to differentiate the clusters
[9]:
# Hand-picked colors for the first few clusters
label_color_map = {0: 'blue', 1: 'black', 2: 'red', 3: 'magenta', 4: 'yellow', 5: 'green', 6: 'gray'}
# If more clusters are requested than the hand-picked palette covers (e.g. in
# the exercise below), sample the remaining colors evenly from a colormap so
# every cluster still gets its own color instead of failing the check below.
if n_centroids > len(label_color_map):
    extra_cmap = plt.get_cmap("gist_rainbow", n_centroids)
    for extra_label in range(len(label_color_map), n_centroids):
        label_color_map[extra_label] = matplotlib.colors.to_hex(extra_cmap(extra_label))
# make sure we have a color for each cluster (total number of clusters specified by n_centroids)
assert len(label_color_map) >= n_centroids
Plot the clusters. Each color represents a cluster
[10]:
# Draw every observation, colored by the cluster it was assigned to
for cluster_id, cluster_color in label_color_map.items():
    in_cluster = labels == cluster_id
    plt.plot(data[in_cluster, 0], data[in_cluster, 1], 'o', color=cluster_color)
# Overlay each centroid as a large star in its cluster's color
centroid_colors = list(label_color_map.values())[:n_centroids]
plt.scatter(
    centroids[:, 0],
    centroids[:, 1],
    s=320,
    marker='*',
    c=centroid_colors,
    edgecolors='gray',
);
plt.show()

Exercise: Change the number of clusters and number of observations and see how the clusters change#
[ ]: