Color quantization with a self-organizing map

Pekka Väänänen | 30fps.net | March 8th, 2024

The popular ScreenToGif recording tool includes a high-quality color quantizer called NeuralQuantizer. I assumed it would be a simple fully connected network, but it's actually a one-dimensional self-organizing map! I have to admit this was quite exciting to me, because it's the first time I've encountered this classic machine learning model in the wild.
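For readers who haven't met a self-organizing map before: a 1-D SOM is a chain of units, each holding a weight vector (here, an RGB color). Each training step picks a sample, finds its best-matching unit (the unit with the closest weight), and pulls that unit and its neighbors along the chain towards the sample. The sketch below is a minimal NumPy version under my own assumptions (a Gaussian neighborhood over unit indices, a linearly decaying learning rate and width, and a hypothetical function name `train_som_1d`). It is not NeuralQuantizer's or sklearn-som's actual code.

```python
import numpy as np

def train_som_1d(data, n_units=16, n_iter=2000, lr0=0.5, sigma0=4.0, seed=0):
    """Train a 1-D self-organizing map on the rows of `data` (N x dim).

    Sketch only: Gaussian neighborhood over unit indices, with the learning
    rate and neighborhood width decaying linearly towards zero.
    """
    rng = np.random.default_rng(seed)
    dim = data.shape[1]
    # Greyscale-style start: unit i begins at i/(n_units-1) in every channel.
    weights = np.tile(np.linspace(0, 1, n_units)[:, np.newaxis], (1, dim))
    unit_pos = np.arange(n_units)  # each unit's position along the chain
    for t in range(n_iter):
        frac = 1.0 - t / n_iter            # goes from 1 down towards 0
        lr = lr0 * frac
        sigma = max(sigma0 * frac, 1e-3)   # keep the width strictly positive
        x = data[rng.integers(len(data))]  # one randomly chosen sample
        # Best-matching unit: smallest squared Euclidean distance to x.
        bmu = np.argmin(np.sum((weights - x) ** 2, axis=1))
        # Units close to the BMU on the chain get pulled towards x as well.
        h = np.exp(-((unit_pos - bmu) ** 2) / (2 * sigma ** 2))
        weights += lr * h[:, np.newaxis] * (x - weights)
    return weights

rng = np.random.default_rng(1)
pixels = rng.random((5000, 3))  # stand-in for a flattened RGB image in [0,1]
palette = train_som_1d(pixels, n_units=8)
print(palette.shape)  # (8, 3)
```

Since each update moves a weight only a fraction `lr * h` of the way towards a sample, weights that start inside the RGB cube stay inside it as long as `lr0` is at most 1.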

I implemented something similar in Python using Riley Smith's neat sklearn-som library. This notebook shows how to do it.

In [35]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

image = Image.open('krull2.png').convert('RGB')  # ensure 3 channels (handles RGBA, greyscale and palette PNGs)
image = image.resize((image.size[0]//4, image.size[1]//4))
image_data = np.array(image)[...,:3]/255. # RGB in [0,1]
image_flat = image_data.reshape(-1,3)     # Nx3 shape

plt.imshow(image_data)

We initialize the palette (the SOM's "weights") to a greyscale gradient, like NeuralQuantizer does. This seems to give a much better starting point than the normally distributed noise sklearn-som uses by default.
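Why might the gradient help? My guess, not something the post measures: every gradient entry is already a valid color, and units that are adjacent on the 1-D map start out with nearly identical colors, so the map is ordered from the first step. The sketch below compares the two starting points on those two properties; the helper names `outside_unit_cube` and `mean_neighbor_distance` are hypothetical.

```python
import numpy as np

M = 256
rng = np.random.default_rng(1234)

# Standard-normal weights, as used by sklearn-som by default.
noise_init = rng.normal(size=(M, 3))
# The greyscale gradient: unit i gets the color (g, g, g) with g = i/(M-1).
grey_init = np.tile(np.linspace(0, 1, M)[:, np.newaxis], (1, 3))

def outside_unit_cube(w):
    """Fraction of weight components that are not valid [0,1] color values."""
    return np.mean((w < 0) | (w > 1))

def mean_neighbor_distance(w):
    """Average Euclidean distance between units adjacent on the 1-D map."""
    return np.mean(np.linalg.norm(np.diff(w, axis=0), axis=1))

print(f"outside [0,1]: noise {outside_unit_cube(noise_init):.2f}, "
      f"grey {outside_unit_cube(grey_init):.2f}")
print(f"neighbor distance: noise {mean_neighbor_distance(noise_init):.2f}, "
      f"grey {mean_neighbor_distance(grey_init):.4f}")
```

A standard normal value lands in [0,1] with probability about 0.34, so roughly two thirds of the noise components start outside the color cube, while the gradient has none. Adjacent gradient units are only sqrt(3)/255 ≈ 0.0068 apart.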

In [36]:
from sklearn_som.som import SOM
import time

start = time.time()
M = 256
print(f"Fitting a palette of {M} colors")

som = SOM(m=M, n=1, dim=3, lr=1.0, sigma=2, max_iter=3000, random_state=1234)
# Start with a greyscale palette. Create an Mx3 array with values in range [0,1].
som.weights = np.tile(np.linspace(0,1,M)[:,np.newaxis], (1,3))

som.fit(image_flat)

# Compute clusters and assignments
assignments = som.predict(image_flat)
# Extract an Mx3 array of the colors the algorithm chose.
# cluster_centers_ has shape (m, n, dim) = (M, 1, 3), so drop the n axis.
palette = som.cluster_centers_.copy()[:,0,:]
assignments_image = assignments.reshape(*image_data.shape[:2])
output_image = np.take(palette, assignments, axis=0).reshape((*image_data.shape[:2],3))

# Take the palette indices that were actually used and pack them into an image.
uniq, counts = np.unique(assignments, return_counts=True)
used_colors = palette[uniq]
palette_image = np.zeros((16,M//16,3))
palette_image.reshape(-1,3)[:used_colors.shape[0]] = used_colors

print(f"Took {time.time()-start:.3f} seconds. Used {uniq.shape[0]} colors.")
Fitting a palette of 256 colors
Took 5.286 seconds. Used 205 colors.
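`som.predict` returns, for each pixel, the index of its best-matching unit, i.e. the palette entry closest to it. Assuming closeness is measured with Euclidean distance (the standard SOM choice), the same mapping can be done with plain NumPy, which is handy for applying a fitted palette to other frames without the SOM object. The helper `nearest_palette_indices` below is hypothetical, not part of sklearn-som; it processes pixels in chunks so it never builds a full N×M distance matrix for large images.

```python
import numpy as np

def nearest_palette_indices(pixels, palette, chunk=65536):
    """Index of the closest palette color (squared Euclidean distance) for each pixel."""
    out = np.empty(len(pixels), dtype=np.int64)
    pal_sq = np.sum(palette ** 2, axis=1)  # |p|^2 for every palette entry
    for start in range(0, len(pixels), chunk):
        x = pixels[start:start + chunk]
        # |x - p|^2 = |x|^2 - 2 x.p + |p|^2; |x|^2 is the same for every p,
        # so it can be dropped without changing the argmin.
        d = pal_sq[np.newaxis, :] - 2.0 * (x @ palette.T)
        out[start:start + chunk] = np.argmin(d, axis=1)
    return out

palette = np.array([[0, 0, 0], [1, 1, 1], [1, 0, 0]], dtype=float)  # black, white, red
pixels = np.array([[0.1, 0.1, 0.1], [0.9, 0.8, 0.9], [0.8, 0.1, 0.2]])
print(nearest_palette_indices(pixels, palette))  # [0 1 2]
```

Called as `nearest_palette_indices(image_flat, palette)` after the cell above, it should agree with `assignments`, except possibly for pixels that are almost exactly equidistant from two palette colors.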
In [37]:
# Plot the results
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10,10))
ax_input, ax_output, ax_assign, ax_palette = ax.flatten()

ax_input.imshow(image_data)
ax_input.set_title("Input image")
ax_output.imshow(output_image.clip(0,1))
ax_output.set_title(f"Output image ({used_colors.shape[0]} colors)")
ax_assign.imshow(assignments_image)
ax_assign.set_title("Palette indices")
ax_palette.imshow(palette_image.clip(0,1))
ax_palette.set_title("Palette")

for a in ax.flatten():
    a.axis('off')

plt.suptitle("Color palette found with a self-organizing map")
plt.tight_layout()
plt.show()