In [None]:
# Load the lab_black notebook extension (presumably auto-formats code cells with black;
# a development convenience — TODO confirm it is not required to run the analysis).
%load_ext lab_black

In [None]:
from itertools import combinations

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.spatial.distance import cdist
import seaborn as sns
from sklearn.datasets import make_blobs
from sklearn.cluster import AgglomerativeClustering

# Apply seaborn's default theme, but with the plain "white" axes style.
sns.set(style="white")

In [None]:
def plot_clustering(data, labels, centroids=None, new_centroids=None):
    """Scatter-plot 2-D points coloured by cluster label, optionally with centroids.

    Parameters
    ----------
    data : ndarray, shape (n_samples, >=2)
        Points to plot; only the first two columns are drawn.
    labels : sequence of length n_samples
        Cluster label of each point, used as the colour (hue).
    centroids : ndarray, shape (n_clusters, >=2), optional
        Current centroids, drawn as "X" markers. They are red when they are
        the only centroids shown and grey when ``new_centroids`` is also given.
    new_centroids : ndarray, shape (n_clusters, >=2), optional
        Updated centroids, drawn as red "X" markers. When ``centroids`` is
        also given, a line joins each old centroid to its new position
        (rows are paired by index).
    """
    plt.figure(figsize=(8, 9))

    # plot data
    ax = sns.scatterplot(
        x=data[:, 0],
        y=data[:, 1],
        hue=labels,
        # a fixed hue order keeps the label -> colour mapping consistent between plots
        hue_order=sorted(np.unique(labels)),
        palette="muted",
        legend=False,
        s=100,
    )

    # plot current centroids (greyed out when updated ones are shown as well)
    if centroids is not None:
        sns.scatterplot(
            x=centroids[:, 0],
            y=centroids[:, 1],
            marker="X",
            color="r" if new_centroids is None else "grey",
            s=200,
            ax=ax,
        )

    if new_centroids is not None:
        sns.scatterplot(
            x=new_centroids[:, 0],
            y=new_centroids[:, 1],
            marker="X",
            color="r",
            s=200,
            ax=ax,
        )

        # Connect each old centroid to its updated position. Iterate over the
        # arrays themselves instead of a global cluster count so the plot does
        # not depend on notebook state (e.g. `k` being reassigned elsewhere).
        if centroids is not None:
            for old, new in zip(centroids, new_centroids):
                ax.plot([old[0], new[0]], [old[1], new[1]], zorder=-1)

    ax.set_xticks([])
    ax.set_yticks([])

# Data

In [None]:
# Three 2-D Gaussian blobs with increasing spread. n_samples=100 is
# make_blobs' default, spelled out here; random_state fixes the draw.
blob_centers = [[-2, -2], [-2, 0], [2, 2]]
blob_stds = [0.4, 0.5, 1]

data, true_labels = make_blobs(
    n_samples=100,
    n_features=2,
    centers=blob_centers,
    cluster_std=blob_stds,
    random_state=42,
)

plot_clustering(data, true_labels);

# K-means

## Initialization

In [None]:
# Hand-picked starting centroids, one row per cluster (float so their dtype
# matches the mean-based centroids produced by the M step).
centroids = np.array([[-3, 2], [0, -2], [2, 0]], dtype=float)
# Alternative starting centroids to experiment with:
# centroids = np.array([[1, 4], [-3, -1], [2, 0]], dtype=float)

# Derive the number of clusters from the centroids so the two cannot disagree.
k = len(centroids)
# -1 is a placeholder meaning "not yet assigned" (all points share one colour).
labels = np.full(data.shape[0], -1)
new_centroids = None
step = "E"  # the iteration cell starts with an assignment (E) step

plot_clustering(data, labels, centroids, new_centroids)

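## Iterations

Run the next cell repeatedly. Each run performs one step and then switches to the other: the E step assigns every point to its nearest centroid, and the M step moves each centroid to the mean of the points assigned to it.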

In [None]:
print(step, "step")
if step == "E":
    # E step: assign every point to its nearest centroid (Euclidean distance).
    distances_to_centroids = cdist(data, centroids)
    new_labels = distances_to_centroids.argmin(axis=1)
    plot_clustering(data, new_labels, centroids)
    step = "M"
else:
    # M step: move each centroid to the mean of the points assigned to it.
    # The cluster count comes from `centroids`, not the global `k`, because
    # the HAC section reassigns `k`. A cluster that lost all its points keeps
    # its previous centroid: the mean of an empty selection is NaN, and
    # argmin would then assign every point to that NaN centroid.
    updated_centroids = []
    for c, old_centroid in enumerate(centroids):
        members = data[new_labels == c]
        updated_centroids.append(
            members.mean(axis=0) if len(members) > 0 else old_centroid
        )
    new_centroids = np.array(updated_centroids, dtype=float)
    plot_clustering(data, new_labels, centroids, new_centroids)
    centroids = new_centroids
    step = "E"

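# Hierarchical Agglomerative Clustering

Start with every point in its own cluster (`k` = number of points). Then re-run the second cell below: each run fits with `k` clusters, plots the result and decrements `k`, so successive runs show clusters being merged.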

In [None]:
# Start with one cluster per point; the next cell lowers k by one per run.
k = len(data)

In [None]:
print("k:", k)
hac = AgglomerativeClustering(n_clusters=k)
labels = hac.fit_predict(data)
plot_clustering(data, labels)
# Step down one cluster for the next run, but never below 1:
# AgglomerativeClustering rejects n_clusters < 1, so an unbounded
# decrement would make the cell fail once k reached 0.
k = max(k - 1, 1)