Author: Khalid Salama
Date created: 2021/02/28
Last modified: 2021/02/28
Description: Semantic Clustering by Adopting Nearest neighbors (SCAN) algorithm.
This example demonstrates how to apply the Semantic Clustering by Adopting Nearest neighbors (SCAN) algorithm (Van Gansbeke et al., 2020) on the CIFAR-10 dataset. The algorithm consists of two phases:
1. Self-supervised visual representation learning of images, in which we use the simCLR technique.
2. Clustering of the learned visual representation vectors to maximize the agreement between the cluster assignments of neighbouring vectors.
The example requires TensorFlow Addons, which you can install using the following command:
pip install tensorflow-addons
from collections import defaultdict
import random
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from tqdm import tqdm
num_classes = 10
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_data = np.concatenate([x_train, x_test])
y_data = np.concatenate([y_train, y_test])
print("x_data shape:", x_data.shape, "- y_data shape:", y_data.shape)
classes = [
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
]
x_data shape: (60000, 32, 32, 3) - y_data shape: (60000, 1)
target_size = 32 # Resize the input images.
representation_dim = 512 # The dimensions of the features vector.
projection_units = 128 # The projection head of the representation learner.
num_clusters = 20 # Number of clusters.
k_neighbours = 5 # Number of neighbours to consider during cluster learning.
tune_encoder_during_clustering = False # Freeze the encoder in the cluster learning.
The data preprocessing step resizes the input images to the desired target_size and applies feature-wise normalization. Note that, when using keras.applications.ResNet50V2 as the visual encoder, resizing the images into 255 x 255 inputs would lead to more accurate results but would require a longer time to train.
data_preprocessing = keras.Sequential(
[
layers.experimental.preprocessing.Resizing(target_size, target_size),
layers.experimental.preprocessing.Normalization(),
]
)
# Compute the mean and the variance from the data for normalization.
data_preprocessing.layers[-1].adapt(x_data)
Unlike simCLR, which randomly picks a single data augmentation function to apply to an input image, we apply a set of data augmentation functions randomly to the input image. (You can experiment with other image augmentation techniques by following the data augmentation tutorial.)
data_augmentation = keras.Sequential(
[
layers.experimental.preprocessing.RandomTranslation(
height_factor=(-0.2, 0.2), width_factor=(-0.2, 0.2), fill_mode="nearest"
),
layers.experimental.preprocessing.RandomFlip(mode="horizontal"),
layers.experimental.preprocessing.RandomRotation(
factor=0.15, fill_mode="nearest"
),
layers.experimental.preprocessing.RandomZoom(
height_factor=(-0.3, 0.1), width_factor=(-0.3, 0.1), fill_mode="nearest"
)
]
)
Display a random image
image_idx = np.random.choice(range(x_data.shape[0]))
image = x_data[image_idx]
image_class = classes[y_data[image_idx][0]]
plt.figure(figsize=(3, 3))
plt.imshow(x_data[image_idx].astype("uint8"))
plt.title(image_class)
_ = plt.axis("off")
Display a sample of augmented versions of the image
plt.figure(figsize=(10, 10))
for i in range(9):
augmented_images = data_augmentation(np.array([image]))
ax = plt.subplot(3, 3, i + 1)
plt.imshow(augmented_images[0].numpy().astype("uint8"))
plt.axis("off")
def create_encoder(representation_dim):
encoder = keras.Sequential(
[
keras.applications.ResNet50V2(
include_top=False, weights=None, pooling="avg"
),
layers.Dense(representation_dim),
]
)
return encoder
class RepresentationLearner(keras.Model):
def __init__(
self,
encoder,
projection_units,
num_augmentations,
temperature=1.0,
dropout_rate=0.1,
l2_normalize=False,
**kwargs
):
super(RepresentationLearner, self).__init__(**kwargs)
self.encoder = encoder
# Create projection head.
self.projector = keras.Sequential(
[
layers.Dropout(dropout_rate),
layers.Dense(units=projection_units, use_bias=False),
layers.BatchNormalization(),
layers.ReLU(),
]
)
self.num_augmentations = num_augmentations
self.temperature = temperature
self.l2_normalize = l2_normalize
self.loss_tracker = keras.metrics.Mean(name="loss")
@property
def metrics(self):
return [self.loss_tracker]
def compute_contrastive_loss(self, feature_vectors, batch_size):
num_augmentations = tf.shape(feature_vectors)[0] // batch_size
if self.l2_normalize:
feature_vectors = tf.math.l2_normalize(feature_vectors, -1)
# The logits shape is [num_augmentations * batch_size, num_augmentations * batch_size].
logits = (
tf.linalg.matmul(feature_vectors, feature_vectors, transpose_b=True)
/ self.temperature
)
        # Subtract the max logit of each row for numerical stability.
        logits_max = tf.math.reduce_max(logits, axis=1, keepdims=True)
        logits = logits - logits_max
        # The shape of targets is [num_augmentations * batch_size, num_augmentations * batch_size].
        # targets consists of a num_augmentations x num_augmentations grid of [batch_size, batch_size]
        # identity submatrices: entry (i, j) is 1 when rows i and j are views of the same image.
targets = tf.tile(tf.eye(batch_size), [num_augmentations, num_augmentations])
# Compute cross entropy loss
return keras.losses.categorical_crossentropy(
y_true=targets, y_pred=logits, from_logits=True
)
def call(self, inputs):
# Preprocess the input images.
preprocessed = data_preprocessing(inputs)
# Create augmented versions of the images.
augmented = []
for _ in range(self.num_augmentations):
augmented.append(data_augmentation(preprocessed))
augmented = layers.Concatenate(axis=0)(augmented)
# Generate embedding representations of the images.
features = self.encoder(augmented)
# Apply projection head.
return self.projector(features)
def train_step(self, inputs):
batch_size = tf.shape(inputs)[0]
# Run the forward pass and compute the contrastive loss
with tf.GradientTape() as tape:
feature_vectors = self(inputs, training=True)
loss = self.compute_contrastive_loss(feature_vectors, batch_size)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update loss tracker metric
self.loss_tracker.update_state(loss)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
def test_step(self, inputs):
batch_size = tf.shape(inputs)[0]
feature_vectors = self(inputs, training=False)
loss = self.compute_contrastive_loss(feature_vectors, batch_size)
self.loss_tracker.update_state(loss)
return {"loss": self.loss_tracker.result()}
# Create vision encoder.
encoder = create_encoder(representation_dim)
# Create representation learner.
representation_learner = RepresentationLearner(
encoder, projection_units, num_augmentations=2, temperature=0.1
)
# Create a cosine decay learning rate scheduler.
lr_scheduler = keras.experimental.CosineDecay(
initial_learning_rate=0.001, decay_steps=500, alpha=0.1
)
# Compile the model.
representation_learner.compile(
optimizer=tfa.optimizers.AdamW(learning_rate=lr_scheduler, weight_decay=0.0001),
)
# Fit the model.
history = representation_learner.fit(
x=x_data,
batch_size=512,
epochs=50, # for better results, increase the number of epochs to 500.
)
Epoch 1/50
118/118 [==============================] - 29s 135ms/step - loss: 41.1135
Epoch 2/50
118/118 [==============================] - 15s 125ms/step - loss: 11.7141
Epoch 3/50
118/118 [==============================] - 15s 125ms/step - loss: 11.1728
Epoch 4/50
118/118 [==============================] - 15s 125ms/step - loss: 10.9717
Epoch 5/50
118/118 [==============================] - 15s 125ms/step - loss: 10.8574
Epoch 6/50
118/118 [==============================] - 15s 125ms/step - loss: 10.9496
Epoch 7/50
118/118 [==============================] - 15s 124ms/step - loss: 10.7493
Epoch 8/50
118/118 [==============================] - 15s 124ms/step - loss: 10.5979
Epoch 9/50
118/118 [==============================] - 15s 124ms/step - loss: 10.4613
Epoch 10/50
118/118 [==============================] - 15s 125ms/step - loss: 10.2900
Epoch 11/50
118/118 [==============================] - 15s 124ms/step - loss: 10.1303
Epoch 12/50
118/118 [==============================] - 15s 124ms/step - loss: 9.9608
Epoch 13/50
118/118 [==============================] - 15s 124ms/step - loss: 9.7788
Epoch 14/50
118/118 [==============================] - 15s 124ms/step - loss: 9.5830
Epoch 15/50
118/118 [==============================] - 15s 124ms/step - loss: 9.4038
Epoch 16/50
118/118 [==============================] - 15s 124ms/step - loss: 9.1887
Epoch 17/50
118/118 [==============================] - 15s 124ms/step - loss: 9.0000
Epoch 18/50
118/118 [==============================] - 15s 124ms/step - loss: 8.7764
Epoch 19/50
118/118 [==============================] - 15s 124ms/step - loss: 8.5784
Epoch 20/50
118/118 [==============================] - 15s 124ms/step - loss: 8.3592
Epoch 21/50
118/118 [==============================] - 15s 124ms/step - loss: 8.2545
Epoch 22/50
118/118 [==============================] - 15s 124ms/step - loss: 8.1171
Epoch 23/50
118/118 [==============================] - 15s 124ms/step - loss: 7.9598
Epoch 24/50
118/118 [==============================] - 15s 124ms/step - loss: 7.8623
Epoch 25/50
118/118 [==============================] - 15s 124ms/step - loss: 7.7169
Epoch 26/50
118/118 [==============================] - 15s 124ms/step - loss: 7.5100
Epoch 27/50
118/118 [==============================] - 15s 124ms/step - loss: 7.5887
Epoch 28/50
118/118 [==============================] - 15s 124ms/step - loss: 7.3511
Epoch 29/50
118/118 [==============================] - 15s 124ms/step - loss: 7.1647
Epoch 30/50
118/118 [==============================] - 15s 124ms/step - loss: 7.1549
Epoch 31/50
118/118 [==============================] - 15s 124ms/step - loss: 7.0462
Epoch 32/50
118/118 [==============================] - 15s 124ms/step - loss: 6.8149
Epoch 33/50
118/118 [==============================] - 15s 124ms/step - loss: 6.6954
Epoch 34/50
118/118 [==============================] - 15s 124ms/step - loss: 6.5354
Epoch 35/50
118/118 [==============================] - 15s 124ms/step - loss: 6.3982
Epoch 36/50
118/118 [==============================] - 15s 124ms/step - loss: 6.4175
Epoch 37/50
118/118 [==============================] - 15s 124ms/step - loss: 6.3820
Epoch 38/50
118/118 [==============================] - 15s 124ms/step - loss: 6.2560
Epoch 39/50
118/118 [==============================] - 15s 124ms/step - loss: 6.1237
Epoch 40/50
118/118 [==============================] - 15s 124ms/step - loss: 6.0485
Epoch 41/50
118/118 [==============================] - 15s 124ms/step - loss: 5.8846
Epoch 42/50
118/118 [==============================] - 15s 124ms/step - loss: 5.7548
Epoch 43/50
118/118 [==============================] - 15s 124ms/step - loss: 6.0794
Epoch 44/50
118/118 [==============================] - 15s 124ms/step - loss: 5.9023
Epoch 45/50
118/118 [==============================] - 15s 124ms/step - loss: 5.9548
Epoch 46/50
118/118 [==============================] - 15s 124ms/step - loss: 6.0809
Epoch 47/50
118/118 [==============================] - 15s 124ms/step - loss: 5.6123
Epoch 48/50
118/118 [==============================] - 15s 124ms/step - loss: 5.5667
Epoch 49/50
118/118 [==============================] - 15s 124ms/step - loss: 5.4573
Epoch 50/50
118/118 [==============================] - 15s 124ms/step - loss: 5.4597
Plot training loss
plt.plot(history.history["loss"])
plt.ylabel("loss")
plt.xlabel("epoch")
plt.show()
batch_size = 500
# Get the feature vector representations of the images.
feature_vectors = encoder.predict(x_data, batch_size=batch_size, verbose=1)
# Normalize the feature vectors.
feature_vectors = tf.math.l2_normalize(feature_vectors, -1)
120/120 [==============================] - 4s 18ms/step
neighbours = []
num_batches = feature_vectors.shape[0] // batch_size
for batch_idx in tqdm(range(num_batches)):
start_idx = batch_idx * batch_size
end_idx = start_idx + batch_size
current_batch = feature_vectors[start_idx:end_idx]
# Compute the dot similarity.
similarities = tf.linalg.matmul(current_batch, feature_vectors, transpose_b=True)
# Get the indices of most similar vectors.
_, indices = tf.math.top_k(similarities, k=k_neighbours + 1, sorted=True)
# Add the indices to the neighbours.
neighbours.append(indices[..., 1:])
neighbours = np.reshape(np.array(neighbours), (-1, k_neighbours))
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:00<00:00, 304.24it/s]
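As an optional sanity check (not part of the original flow), note that because the feature vectors are L2-normalized, each vector's highest dot similarity is with itself; this is why k_neighbours + 1 indices are retrieved above and the first one is dropped. The assertion below assumes no other image yields an identical feature vector.

# Optional check: the top-1 "neighbour" of an anchor is the anchor itself.
anchor_idx = 0
sims = tf.linalg.matvec(feature_vectors, feature_vectors[anchor_idx])
top_indices = tf.math.top_k(sims, k=k_neighbours + 1).indices.numpy()
assert top_indices[0] == anchor_idx  # assumes no exact duplicate feature vector
print("neighbours of image", anchor_idx, ":", top_indices[1:])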
Let's display some neighbours on each row
nrows = 4
ncols = k_neighbours + 1
plt.figure(figsize=(12, 12))
position = 1
for _ in range(nrows):
anchor_idx = np.random.choice(range(x_data.shape[0]))
    neighbour_indices = neighbours[anchor_idx]
    indices = [anchor_idx] + neighbour_indices.tolist()
for j in range(ncols):
plt.subplot(nrows, ncols, position)
plt.imshow(x_data[indices[j]].astype("uint8"))
plt.title(classes[y_data[indices[j]][0]])
plt.axis("off")
position += 1
Notice that the images in each row are visually similar and tend to belong to similar classes.
This loss encourages an image and its neighbours to have consistent cluster assignments.
class ClustersConsistencyLoss(keras.losses.Loss):
def __init__(self):
super(ClustersConsistencyLoss, self).__init__()
def __call__(self, target, similarity, sample_weight=None):
# Set targets to be ones.
target = tf.ones_like(similarity)
# Compute cross entropy loss.
loss = keras.losses.binary_crossentropy(
y_true=target, y_pred=similarity, from_logits=True
)
return tf.math.reduce_mean(loss)
This loss encourages the cluster distribution to be roughly uniform, to avoid assigning most of the instances to a single cluster.
class ClustersEntropyLoss(keras.losses.Loss):
def __init__(self, entropy_loss_weight=1.0):
super(ClustersEntropyLoss, self).__init__()
self.entropy_loss_weight = entropy_loss_weight
def __call__(self, target, cluster_probabilities, sample_weight=None):
# Ideal entropy = log(num_clusters).
num_clusters = tf.cast(tf.shape(cluster_probabilities)[-1], tf.dtypes.float32)
target = tf.math.log(num_clusters)
# Compute the overall clusters distribution.
cluster_probabilities = tf.math.reduce_mean(cluster_probabilities, axis=0)
# Replacing zero probabilities - if any - with a very small value.
cluster_probabilities = tf.clip_by_value(
cluster_probabilities, clip_value_min=1e-8, clip_value_max=1.0
)
# Compute the entropy over the clusters.
entropy = -tf.math.reduce_sum(
cluster_probabilities * tf.math.log(cluster_probabilities)
)
# Compute the difference between the target and the actual.
loss = target - entropy
return loss
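A quick numeric check (illustrative only): the loss is near zero when the average cluster distribution is uniform, and approaches log(num_clusters) when all instances collapse into a single cluster.

# Illustrative check of ClustersEntropyLoss on toy cluster probabilities.
entropy_loss = ClustersEntropyLoss()
uniform = tf.fill([4, num_clusters], 1.0 / num_clusters)  # 4 uniform assignments
collapsed = tf.one_hot(tf.zeros(4, dtype=tf.int32), depth=num_clusters)  # all in one cluster
print("uniform loss  :", entropy_loss(None, uniform).numpy())    # ~0.0
print("collapsed loss:", entropy_loss(None, collapsed).numpy())  # ~log(20) ~= 3.0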
This model takes a raw image as input, generates its feature vector using the trained encoder, and produces a probability distribution over the clusters, which serves as the cluster assignment for the image.
def create_clustering_model(encoder, num_clusters, name=None):
inputs = keras.Input(shape=input_shape)
# Preprocess the input images.
preprocessed = data_preprocessing(inputs)
# Apply data augmentation to the images.
augmented = data_augmentation(preprocessed)
# Generate embedding representations of the images.
features = encoder(augmented)
# Assign the images to clusters.
outputs = layers.Dense(units=num_clusters, activation="softmax")(features)
# Create the model.
model = keras.Model(inputs=inputs, outputs=outputs, name=name)
return model
This model receives the input anchor image and its neighbours, produces the cluster assignments for them using the clustering_model, and produces two outputs:
1. similarity: the similarity between the cluster assignments of the anchor image and its neighbours. This output is fed to the ClustersConsistencyLoss.
2. anchor_clustering: the cluster assignments of the anchor images. This output is fed to the ClustersEntropyLoss.
def create_clustering_learner(clustering_model):
anchor = keras.Input(shape=input_shape, name="anchors")
neighbours = keras.Input(
shape=tuple([k_neighbours]) + input_shape, name="neighbours"
)
# Changes neighbours shape to [batch_size * k_neighbours, width, height, channels]
neighbours_reshaped = tf.reshape(neighbours, shape=tuple([-1]) + input_shape)
# anchor_clustering shape: [batch_size, num_clusters]
anchor_clustering = clustering_model(anchor)
# neighbours_clustering shape: [batch_size * k_neighbours, num_clusters]
neighbours_clustering = clustering_model(neighbours_reshaped)
# Convert neighbours_clustering shape to [batch_size, k_neighbours, num_clusters]
neighbours_clustering = tf.reshape(
neighbours_clustering,
shape=(-1, k_neighbours, tf.shape(neighbours_clustering)[-1]),
)
# similarity shape: [batch_size, 1, k_neighbours]
similarity = tf.linalg.einsum(
"bij,bkj->bik", tf.expand_dims(anchor_clustering, axis=1), neighbours_clustering
)
# similarity shape: [batch_size, k_neighbours]
similarity = layers.Lambda(lambda x: tf.squeeze(x, axis=1), name="similarity")(
similarity
)
# Create the model.
model = keras.Model(
inputs=[anchor, neighbours],
outputs=[similarity, anchor_clustering],
name="clustering_learner",
)
return model
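The einsum above carries the consistency signal: for each anchor, it computes the dot product between the anchor's cluster-assignment vector and each of its k_neighbours' assignment vectors. A standalone shape check with random stand-in tensors (illustrative only):

# Shape check of the similarity einsum with random stand-in tensors.
toy_anchor = tf.random.uniform((3, 1, num_clusters))                 # [batch, 1, clusters]
toy_neighbours = tf.random.uniform((3, k_neighbours, num_clusters))  # [batch, k, clusters]
toy_similarity = tf.linalg.einsum("bij,bkj->bik", toy_anchor, toy_neighbours)
print(toy_similarity.shape)  # (3, 1, 5); squeezed to (3, 5) inside the model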
# If tune_encoder_during_clustering is set to False,
# then freeze the encoder weights.
for layer in encoder.layers:
layer.trainable = tune_encoder_during_clustering
# Create the clustering model and learner.
clustering_model = create_clustering_model(encoder, num_clusters, name="clustering")
clustering_learner = create_clustering_learner(clustering_model)
# Instantiate the model losses.
losses = [ClustersConsistencyLoss(), ClustersEntropyLoss(entropy_loss_weight=5)]
# Create the model inputs and labels.
inputs = {"anchors": x_data, "neighbours": tf.gather(x_data, neighbours)}
labels = tf.ones(shape=(x_data.shape[0]))
# Compile the model.
clustering_learner.compile(
optimizer=tfa.optimizers.AdamW(learning_rate=0.0005, weight_decay=0.0001),
loss=losses,
)
# Begin training the model.
history = clustering_learner.fit(x=inputs, y=labels, batch_size=512, epochs=50)
Epoch 1/50
118/118 [==============================] - 20s 95ms/step - loss: 0.6655 - similarity_loss: 0.6642 - clustering_loss: 0.0013
Epoch 2/50
118/118 [==============================] - 10s 86ms/step - loss: 0.6361 - similarity_loss: 0.6325 - clustering_loss: 0.0036
Epoch 3/50
118/118 [==============================] - 10s 85ms/step - loss: 0.6129 - similarity_loss: 0.6070 - clustering_loss: 0.0059
Epoch 4/50
118/118 [==============================] - 10s 85ms/step - loss: 0.6005 - similarity_loss: 0.5930 - clustering_loss: 0.0075
Epoch 5/50
118/118 [==============================] - 10s 85ms/step - loss: 0.5923 - similarity_loss: 0.5849 - clustering_loss: 0.0074
Epoch 6/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5879 - similarity_loss: 0.5795 - clustering_loss: 0.0084
Epoch 7/50
118/118 [==============================] - 10s 85ms/step - loss: 0.5841 - similarity_loss: 0.5754 - clustering_loss: 0.0087
Epoch 8/50
118/118 [==============================] - 10s 85ms/step - loss: 0.5817 - similarity_loss: 0.5733 - clustering_loss: 0.0084
Epoch 9/50
118/118 [==============================] - 10s 85ms/step - loss: 0.5811 - similarity_loss: 0.5717 - clustering_loss: 0.0094
Epoch 10/50
118/118 [==============================] - 10s 85ms/step - loss: 0.5797 - similarity_loss: 0.5697 - clustering_loss: 0.0100
Epoch 11/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5767 - similarity_loss: 0.5676 - clustering_loss: 0.0091
Epoch 12/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5771 - similarity_loss: 0.5667 - clustering_loss: 0.0104
Epoch 13/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5755 - similarity_loss: 0.5661 - clustering_loss: 0.0094
Epoch 14/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5746 - similarity_loss: 0.5653 - clustering_loss: 0.0093
Epoch 15/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5743 - similarity_loss: 0.5640 - clustering_loss: 0.0103
Epoch 16/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5738 - similarity_loss: 0.5636 - clustering_loss: 0.0102
Epoch 17/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5732 - similarity_loss: 0.5627 - clustering_loss: 0.0106
Epoch 18/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5723 - similarity_loss: 0.5621 - clustering_loss: 0.0102
Epoch 19/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5711 - similarity_loss: 0.5615 - clustering_loss: 0.0096
Epoch 20/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5693 - similarity_loss: 0.5596 - clustering_loss: 0.0097
Epoch 21/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5699 - similarity_loss: 0.5600 - clustering_loss: 0.0099
Epoch 22/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5694 - similarity_loss: 0.5592 - clustering_loss: 0.0102
Epoch 23/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5703 - similarity_loss: 0.5595 - clustering_loss: 0.0108
Epoch 24/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5687 - similarity_loss: 0.5587 - clustering_loss: 0.0101
Epoch 25/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5688 - similarity_loss: 0.5585 - clustering_loss: 0.0103
Epoch 26/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5690 - similarity_loss: 0.5583 - clustering_loss: 0.0108
Epoch 27/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5679 - similarity_loss: 0.5572 - clustering_loss: 0.0107
Epoch 28/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5681 - similarity_loss: 0.5573 - clustering_loss: 0.0108
Epoch 29/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5682 - similarity_loss: 0.5572 - clustering_loss: 0.0111
Epoch 30/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5675 - similarity_loss: 0.5571 - clustering_loss: 0.0104
Epoch 31/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5679 - similarity_loss: 0.5562 - clustering_loss: 0.0116
Epoch 32/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5663 - similarity_loss: 0.5554 - clustering_loss: 0.0109
Epoch 33/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5665 - similarity_loss: 0.5556 - clustering_loss: 0.0109
Epoch 34/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5679 - similarity_loss: 0.5568 - clustering_loss: 0.0111
Epoch 35/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5680 - similarity_loss: 0.5563 - clustering_loss: 0.0117
Epoch 36/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5665 - similarity_loss: 0.5553 - clustering_loss: 0.0112
Epoch 37/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5674 - similarity_loss: 0.5556 - clustering_loss: 0.0118
Epoch 38/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5648 - similarity_loss: 0.5543 - clustering_loss: 0.0105
Epoch 39/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5653 - similarity_loss: 0.5549 - clustering_loss: 0.0103
Epoch 40/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5656 - similarity_loss: 0.5544 - clustering_loss: 0.0113
Epoch 41/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5644 - similarity_loss: 0.5542 - clustering_loss: 0.0102
Epoch 42/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5658 - similarity_loss: 0.5540 - clustering_loss: 0.0118
Epoch 43/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5655 - similarity_loss: 0.5539 - clustering_loss: 0.0116
Epoch 44/50
118/118 [==============================] - 10s 87ms/step - loss: 0.5662 - similarity_loss: 0.5543 - clustering_loss: 0.0119
Epoch 45/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5651 - similarity_loss: 0.5537 - clustering_loss: 0.0114
Epoch 46/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5635 - similarity_loss: 0.5534 - clustering_loss: 0.0101
Epoch 47/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5633 - similarity_loss: 0.5529 - clustering_loss: 0.0103
Epoch 48/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5643 - similarity_loss: 0.5526 - clustering_loss: 0.0117
Epoch 49/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5653 - similarity_loss: 0.5532 - clustering_loss: 0.0121
Epoch 50/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5641 - similarity_loss: 0.5525 - clustering_loss: 0.0117
<tensorflow.python.keras.callbacks.History at 0x7f1da373ea10>
Plot training loss
plt.plot(history.history["loss"])
plt.ylabel("loss")
plt.xlabel("epoch")
plt.show()
# Get the cluster probability distribution of the input images.
clustering_probs = clustering_model.predict(x_data, batch_size=batch_size, verbose=1)
# Get the cluster of the highest probability.
cluster_assignments = tf.math.argmax(clustering_probs, axis=-1).numpy()
# Store the clustering confidence.
# Images with the highest clustering confidence are considered the 'prototypes'
# of the clusters.
cluster_confidence = tf.math.reduce_max(clustering_probs, axis=-1).numpy()
120/120 [==============================] - 3s 20ms/step
Let's compute the cluster sizes
clusters = defaultdict(list)
for idx, c in enumerate(cluster_assignments):
clusters[c].append((idx, cluster_confidence[idx]))
for c in range(num_clusters):
print("cluster", c, ":", len(clusters[c]))
cluster 0 : 4132
cluster 1 : 4057
cluster 2 : 1713
cluster 3 : 2801
cluster 4 : 2511
cluster 5 : 2655
cluster 6 : 2517
cluster 7 : 4493
cluster 8 : 3687
cluster 9 : 1716
cluster 10 : 3397
cluster 11 : 3606
cluster 12 : 3325
cluster 13 : 4010
cluster 14 : 2188
cluster 15 : 3278
cluster 16 : 1902
cluster 17 : 1858
cluster 18 : 3828
cluster 19 : 2326
Notice that the clusters have roughly balanced sizes.
Display the prototypes of each cluster (the instances with the highest clustering confidence):
num_images = 8
plt.figure(figsize=(15, 15))
position = 1
for c in range(num_clusters):
cluster_instances = sorted(clusters[c], key=lambda kv: kv[1], reverse=True)
for j in range(num_images):
image_idx = cluster_instances[j][0]
plt.subplot(num_clusters, num_images, position)
plt.imshow(x_data[image_idx].astype("uint8"))
plt.title(classes[y_data[image_idx][0]])
plt.axis("off")
position += 1
First, we assign a label to each cluster based on the majority label of its images. Then, we compute the accuracy of each cluster by dividing the number of images with the majority label by the size of the cluster.
cluster_label_counts = dict()
for c in range(num_clusters):
cluster_label_counts[c] = [0] * num_classes
instances = clusters[c]
for i, _ in instances:
cluster_label_counts[c][y_data[i][0]] += 1
cluster_label_idx = np.argmax(cluster_label_counts[c])
correct_count = np.max(cluster_label_counts[c])
cluster_size = len(clusters[c])
accuracy = (
np.round((correct_count / cluster_size) * 100, 2) if cluster_size > 0 else 0
)
cluster_label = classes[cluster_label_idx]
print("cluster", c, "label is:", cluster_label, " - accuracy:", accuracy, "%")
cluster 0 label is: frog - accuracy: 23.11 %
cluster 1 label is: truck - accuracy: 23.56 %
cluster 2 label is: bird - accuracy: 29.01 %
cluster 3 label is: dog - accuracy: 16.67 %
cluster 4 label is: truck - accuracy: 27.8 %
cluster 5 label is: ship - accuracy: 36.91 %
cluster 6 label is: deer - accuracy: 27.89 %
cluster 7 label is: dog - accuracy: 23.84 %
cluster 8 label is: airplane - accuracy: 21.7 %
cluster 9 label is: bird - accuracy: 22.38 %
cluster 10 label is: automobile - accuracy: 24.76 %
cluster 11 label is: automobile - accuracy: 24.15 %
cluster 12 label is: cat - accuracy: 17.44 %
cluster 13 label is: truck - accuracy: 23.44 %
cluster 14 label is: ship - accuracy: 31.67 %
cluster 15 label is: airplane - accuracy: 41.06 %
cluster 16 label is: deer - accuracy: 22.77 %
cluster 17 label is: airplane - accuracy: 15.18 %
cluster 18 label is: frog - accuracy: 33.31 %
cluster 19 label is: deer - accuracy: 18.7 %
To improve the accuracy results, you can: 1) increase the number of epochs in the representation learning and the clustering phases; 2) allow the encoder weights to be tuned during the clustering phase; and 3) perform a final fine-tuning step through self-labeling, as described in the original SCAN paper (a sketch of this step follows below). Note that unsupervised image clustering techniques are not expected to outperform supervised image classification techniques; rather, the goal is to show that they can learn the semantics of the images and group them into clusters that resemble their original classes.
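For reference, here is a minimal sketch of the self-labeling idea, simplified relative to the paper (which also applies strong augmentations and re-selects confident samples during training): keep only the images whose cluster confidence exceeds a threshold, treat their cluster assignments as pseudo-labels, and fine-tune the clustering model with a standard cross-entropy loss. The 0.99 threshold and the epoch count are illustrative choices, not values from the example.

# Minimal self-labeling sketch (simplified; the 0.99 threshold is illustrative).
confident = cluster_confidence >= 0.99     # mask of high-confidence images
pseudo_x = x_data[confident]               # the confident images
pseudo_y = cluster_assignments[confident]  # their cluster ids as pseudo-labels
clustering_model.compile(
    optimizer=tfa.optimizers.AdamW(learning_rate=0.0005, weight_decay=0.0001),
    loss=keras.losses.SparseCategoricalCrossentropy(),
)
clustering_model.fit(pseudo_x, pseudo_y, batch_size=512, epochs=10)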