Author: David Griffiths
Date created: 2020/05/25
Last modified: 2020/05/26
Description: Implementation of PointNet for ModelNet10 classification.
Classification, detection and segmentation of unordered 3D point sets, i.e. point clouds, is a core problem in computer vision. This example implements the seminal point cloud deep learning paper PointNet (Qi et al., 2017). For a detailed introduction to PointNet see this blog post.
If using Colab, first install trimesh with !pip install trimesh.
import os
import glob
import trimesh
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt
tf.random.set_seed(1234)
We use the ModelNet10 model dataset, the smaller 10 class version of the ModelNet40 dataset. First download the data:
DATA_DIR = tf.keras.utils.get_file(
    "modelnet.zip",
    "http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip",
    extract=True,
)
DATA_DIR = os.path.join(os.path.dirname(DATA_DIR), "ModelNet10")
Downloading data from http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip
473407488/473402300 [==============================] - 13s 0us/step
We can use the trimesh package to read and visualize the .off mesh files.
mesh = trimesh.load(os.path.join(DATA_DIR, "chair/train/chair_0001.off"))
mesh.show()
To convert a mesh file to a point cloud we first need to sample points on the mesh surface. .sample() performs a uniform random sampling. Here we sample at 2048 locations and visualize in matplotlib.
points = mesh.sample(2048)
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")
ax.scatter(points[:, 0], points[:, 1], points[:, 2])
ax.set_axis_off()
plt.show()
To generate a tf.data.Dataset() we first need to parse the ModelNet data folders. Each mesh is loaded and sampled into a point cloud before being added to a standard Python list and converted to a numpy array. We also store the current enumerate index value as the object label and use a dictionary to recall this later.
def parse_dataset(num_points=2048):

    train_points = []
    train_labels = []
    test_points = []
    test_labels = []
    class_map = {}
    folders = glob.glob(os.path.join(DATA_DIR, "[!README]*"))

    for i, folder in enumerate(folders):
        print("processing class: {}".format(os.path.basename(folder)))
        # store folder name with ID so we can retrieve later
        class_map[i] = folder.split("/")[-1]
        # gather all files
        train_files = glob.glob(os.path.join(folder, "train/*"))
        test_files = glob.glob(os.path.join(folder, "test/*"))

        for f in train_files:
            train_points.append(trimesh.load(f).sample(num_points))
            train_labels.append(i)

        for f in test_files:
            test_points.append(trimesh.load(f).sample(num_points))
            test_labels.append(i)

    return (
        np.array(train_points),
        np.array(test_points),
        np.array(train_labels),
        np.array(test_labels),
        class_map,
    )
Set the number of points to sample and batch size and parse the dataset. This can take ~5 minutes to complete.
NUM_POINTS = 2048
NUM_CLASSES = 10
BATCH_SIZE = 32
train_points, test_points, train_labels, test_labels, CLASS_MAP = parse_dataset(
    NUM_POINTS
)
processing class: bathtub
processing class: desk
processing class: monitor
processing class: sofa
processing class: chair
processing class: toilet
processing class: dresser
processing class: table
processing class: bed
processing class: night_stand
Our data can now be read into a tf.data.Dataset() object. We set the shuffle buffer size to the entire size of the dataset, since prior to this the data is ordered by class. Data augmentation is important when working with point cloud data. We create an augmentation function to jitter and shuffle the train dataset.
def augment(points, label):
    # jitter points
    points += tf.random.uniform(points.shape, -0.005, 0.005, dtype=tf.float64)
    # shuffle points
    points = tf.random.shuffle(points)
    return points, label
train_dataset = tf.data.Dataset.from_tensor_slices((train_points, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_points, test_labels))
train_dataset = train_dataset.shuffle(len(train_points)).map(augment).batch(BATCH_SIZE)
test_dataset = test_dataset.shuffle(len(test_points)).batch(BATCH_SIZE)
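As a quick sanity check (not part of the original example), we can pull a single batch from the training pipeline and confirm that the points have shape (BATCH_SIZE, NUM_POINTS, 3) and the labels have shape (BATCH_SIZE,).
# Optional sanity check: inspect the shapes of one augmented training batch.
for batch_points, batch_labels in train_dataset.take(1):
    print(batch_points.shape)  # (32, 2048, 3)
    print(batch_labels.shape)  # (32,)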
Each convolution and fully-connected layer (with the exception of the end layers) consists of Convolution / Dense -> Batch Normalization -> ReLU Activation.
def conv_bn(x, filters):
    x = layers.Conv1D(filters, kernel_size=1, padding="valid")(x)
    x = layers.BatchNormalization(momentum=0.0)(x)
    return layers.Activation("relu")(x)


def dense_bn(x, filters):
    x = layers.Dense(filters)(x)
    x = layers.BatchNormalization(momentum=0.0)(x)
    return layers.Activation("relu")(x)
PointNet consists of two core components: the primary MLP network and the transformer net (T-net). The T-net aims to learn an affine transformation matrix via its own mini network. The T-net is used twice. The first time it transforms the input points (n, 3) into a canonical representation. The second time it applies an affine transformation for alignment in feature space (here, (n, 32)). As per the original paper we constrain the transformation to be close to an orthogonal matrix (i.e. ||X*X^T - I|| = 0).
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        self.eye = tf.eye(num_features)

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.eye))
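As a minimal check of the regularizer's behaviour (not part of the original example), the penalty should be zero for weights that encode an identity transform and positive for an arbitrary matrix.
# Optional check: identity transform incurs zero penalty, a random matrix does not.
reg_check = OrthogonalRegularizer(num_features=3)
identity_weights = tf.reshape(tf.eye(3), (1, 9))
random_weights = tf.random.uniform((1, 9))
print(reg_check(identity_weights).numpy())  # 0.0
print(reg_check(random_weights).numpy())    # > 0.0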
We can then define a general function to build T-net layers.
def tnet(inputs, num_features):

    # Initialise bias as the identity matrix
    bias = keras.initializers.Constant(np.eye(num_features).flatten())
    reg = OrthogonalRegularizer(num_features)

    x = conv_bn(inputs, 32)
    x = conv_bn(x, 64)
    x = conv_bn(x, 512)
    x = layers.GlobalMaxPooling1D()(x)
    x = dense_bn(x, 256)
    x = dense_bn(x, 128)
    x = layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=bias,
        activity_regularizer=reg,
    )(x)
    feat_T = layers.Reshape((num_features, num_features))(x)
    # Apply affine transformation to input features
    return layers.Dot(axes=(2, 1))([inputs, feat_T])
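The final Dot layer is a batched matrix multiply between the points and the learned transform, so the output keeps the input's shape. A small sketch of the equivalent tf.matmul operation (not part of the original example, using a dummy identity transform):
# layers.Dot(axes=(2, 1)) contracts the feature axis of the points with the
# first axis of the transform: (batch, n, f) x (batch, f, f) -> (batch, n, f).
pts = tf.random.uniform((4, 2048, 3))      # (batch, num_points, num_features)
transform = tf.eye(3, batch_shape=(4,))    # (batch, num_features, num_features)
out = tf.matmul(pts, transform)            # equivalent to Dot(axes=(2, 1))
print(out.shape)                           # (4, 2048, 3)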
The main network can then be implemented in the same manner, with the T-net mini models dropped in as layers in the graph. Here we replicate the network architecture published in the original paper, but with half the number of weights at each layer as we are using the smaller 10 class ModelNet dataset.
inputs = keras.Input(shape=(NUM_POINTS, 3))
x = tnet(inputs, 3)
x = conv_bn(x, 32)
x = conv_bn(x, 32)
x = tnet(x, 32)
x = conv_bn(x, 32)
x = conv_bn(x, 64)
x = conv_bn(x, 512)
x = layers.GlobalMaxPooling1D()(x)
x = dense_bn(x, 256)
x = layers.Dropout(0.3)(x)
x = dense_bn(x, 128)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs, name="pointnet")
model.summary()
Model: "pointnet"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 2048, 3)] 0
__________________________________________________________________________________________________
conv1d (Conv1D) (None, 2048, 32) 128 input_1[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 2048, 32) 128 conv1d[0][0]
__________________________________________________________________________________________________
activation (Activation) (None, 2048, 32) 0 batch_normalization[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, 2048, 64) 2112 activation[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 2048, 64) 256 conv1d_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 2048, 64) 0 batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D) (None, 2048, 512) 33280 activation_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 2048, 512) 2048 conv1d_2[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 2048, 512) 0 batch_normalization_2[0][0]
__________________________________________________________________________________________________
global_max_pooling1d (GlobalMax (None, 512) 0 activation_2[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 256) 131328 global_max_pooling1d[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 256) 1024 dense[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 256) 0 batch_normalization_3[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 128) 32896 activation_3[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 128) 512 dense_1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 128) 0 batch_normalization_4[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 9) 1161 activation_4[0][0]
__________________________________________________________________________________________________
reshape (Reshape) (None, 3, 3) 0 dense_2[0][0]
__________________________________________________________________________________________________
dot (Dot) (None, 2048, 3) 0 input_1[0][0]
reshape[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D) (None, 2048, 32) 128 dot[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 2048, 32) 128 conv1d_3[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 2048, 32) 0 batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv1d_4 (Conv1D) (None, 2048, 32) 1056 activation_5[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 2048, 32) 128 conv1d_4[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 2048, 32) 0 batch_normalization_6[0][0]
__________________________________________________________________________________________________
conv1d_5 (Conv1D) (None, 2048, 32) 1056 activation_6[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 2048, 32) 128 conv1d_5[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 2048, 32) 0 batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv1d_6 (Conv1D) (None, 2048, 64) 2112 activation_7[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 2048, 64) 256 conv1d_6[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 2048, 64) 0 batch_normalization_8[0][0]
__________________________________________________________________________________________________
conv1d_7 (Conv1D) (None, 2048, 512) 33280 activation_8[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 2048, 512) 2048 conv1d_7[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 2048, 512) 0 batch_normalization_9[0][0]
__________________________________________________________________________________________________
global_max_pooling1d_1 (GlobalM (None, 512) 0 activation_9[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 256) 131328 global_max_pooling1d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 256) 1024 dense_3[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 256) 0 batch_normalization_10[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 128) 32896 activation_10[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 128) 512 dense_4[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 128) 0 batch_normalization_11[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 1024) 132096 activation_11[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape) (None, 32, 32) 0 dense_5[0][0]
__________________________________________________________________________________________________
dot_1 (Dot) (None, 2048, 32) 0 activation_6[0][0]
reshape_1[0][0]
__________________________________________________________________________________________________
conv1d_8 (Conv1D) (None, 2048, 32) 1056 dot_1[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 2048, 32) 128 conv1d_8[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 2048, 32) 0 batch_normalization_12[0][0]
__________________________________________________________________________________________________
conv1d_9 (Conv1D) (None, 2048, 64) 2112 activation_12[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 2048, 64) 256 conv1d_9[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 2048, 64) 0 batch_normalization_13[0][0]
__________________________________________________________________________________________________
conv1d_10 (Conv1D) (None, 2048, 512) 33280 activation_13[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 2048, 512) 2048 conv1d_10[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 2048, 512) 0 batch_normalization_14[0][0]
__________________________________________________________________________________________________
global_max_pooling1d_2 (GlobalM (None, 512) 0 activation_14[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 256) 131328 global_max_pooling1d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 256) 1024 dense_6[0][0]
__________________________________________________________________________________________________
activation_15 (Activation) (None, 256) 0 batch_normalization_15[0][0]
__________________________________________________________________________________________________
dropout (Dropout) (None, 256) 0 activation_15[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 128) 32896 dropout[0][0]
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 128) 512 dense_7[0][0]
__________________________________________________________________________________________________
activation_16 (Activation) (None, 128) 0 batch_normalization_16[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 128) 0 activation_16[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 10) 1290 dropout_1[0][0]
==================================================================================================
Total params: 748,979
Trainable params: 742,899
Non-trainable params: 6,080
__________________________________________________________________________________________________
Once the model is defined it can be trained like any other standard classification model using .compile() and .fit().
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    metrics=["sparse_categorical_accuracy"],
)
model.fit(train_dataset, epochs=20, validation_data=test_dataset)
Epoch 1/20
125/125 [==============================] - 28s 221ms/step - loss: 3.5897 - sparse_categorical_accuracy: 0.2724 - val_loss: 5804697916006203392.0000 - val_sparse_categorical_accuracy: 0.3073
Epoch 2/20
125/125 [==============================] - 27s 215ms/step - loss: 3.1970 - sparse_categorical_accuracy: 0.3443 - val_loss: 836343949164544.0000 - val_sparse_categorical_accuracy: 0.3425
Epoch 3/20
125/125 [==============================] - 27s 215ms/step - loss: 2.8959 - sparse_categorical_accuracy: 0.4260 - val_loss: 15107376738729984.0000 - val_sparse_categorical_accuracy: 0.3084
Epoch 4/20
125/125 [==============================] - 27s 215ms/step - loss: 2.7148 - sparse_categorical_accuracy: 0.4939 - val_loss: 6823221.0000 - val_sparse_categorical_accuracy: 0.3304
Epoch 5/20
125/125 [==============================] - 27s 215ms/step - loss: 2.5500 - sparse_categorical_accuracy: 0.5560 - val_loss: 675110905872323182592.0000 - val_sparse_categorical_accuracy: 0.4493
Epoch 6/20
125/125 [==============================] - 27s 215ms/step - loss: 2.3595 - sparse_categorical_accuracy: 0.6081 - val_loss: 600389124096.0000 - val_sparse_categorical_accuracy: 0.5749
Epoch 7/20
125/125 [==============================] - 27s 215ms/step - loss: 2.2485 - sparse_categorical_accuracy: 0.6394 - val_loss: 680423464582760103936.0000 - val_sparse_categorical_accuracy: 0.4912
Epoch 8/20
125/125 [==============================] - 27s 215ms/step - loss: 2.1945 - sparse_categorical_accuracy: 0.6575 - val_loss: 44108689408.0000 - val_sparse_categorical_accuracy: 0.6410
Epoch 9/20
125/125 [==============================] - 27s 215ms/step - loss: 2.1318 - sparse_categorical_accuracy: 0.6725 - val_loss: 873314112.0000 - val_sparse_categorical_accuracy: 0.6112
Epoch 10/20
125/125 [==============================] - 27s 215ms/step - loss: 2.0140 - sparse_categorical_accuracy: 0.7018 - val_loss: 13168980992.0000 - val_sparse_categorical_accuracy: 0.6784
Epoch 11/20
125/125 [==============================] - 27s 215ms/step - loss: 1.9929 - sparse_categorical_accuracy: 0.7056 - val_loss: 36888236785664.0000 - val_sparse_categorical_accuracy: 0.6586
Epoch 12/20
125/125 [==============================] - 27s 215ms/step - loss: 1.9542 - sparse_categorical_accuracy: 0.7166 - val_loss: 85375.9844 - val_sparse_categorical_accuracy: 0.7026
Epoch 13/20
125/125 [==============================] - 27s 215ms/step - loss: 1.8648 - sparse_categorical_accuracy: 0.7447 - val_loss: 7.7962 - val_sparse_categorical_accuracy: 0.5441
Epoch 14/20
125/125 [==============================] - 27s 215ms/step - loss: 1.9016 - sparse_categorical_accuracy: 0.7444 - val_loss: 66469.9062 - val_sparse_categorical_accuracy: 0.6134
Epoch 15/20
125/125 [==============================] - 27s 215ms/step - loss: 1.8003 - sparse_categorical_accuracy: 0.7695 - val_loss: 519227186348032.0000 - val_sparse_categorical_accuracy: 0.6949
Epoch 16/20
125/125 [==============================] - 27s 215ms/step - loss: 1.8019 - sparse_categorical_accuracy: 0.7702 - val_loss: 5263462156149188460544.0000 - val_sparse_categorical_accuracy: 0.6520
Epoch 17/20
125/125 [==============================] - 27s 215ms/step - loss: 1.7177 - sparse_categorical_accuracy: 0.7903 - val_loss: 142240048.0000 - val_sparse_categorical_accuracy: 0.7941
Epoch 18/20
125/125 [==============================] - 27s 216ms/step - loss: 1.7548 - sparse_categorical_accuracy: 0.7855 - val_loss: 2.6049 - val_sparse_categorical_accuracy: 0.5022
Epoch 19/20
125/125 [==============================] - 27s 215ms/step - loss: 1.7101 - sparse_categorical_accuracy: 0.8003 - val_loss: 1152819181305987072.0000 - val_sparse_categorical_accuracy: 0.7753
Epoch 20/20
125/125 [==============================] - 27s 215ms/step - loss: 1.6812 - sparse_categorical_accuracy: 0.8176 - val_loss: 12854714433536.0000 - val_sparse_categorical_accuracy: 0.7390
<tensorflow.python.keras.callbacks.History at 0x7f07e5dd3940>
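If desired, we can also report overall test-set metrics with model.evaluate (a small addition, not shown in the original example):
# Optional: overall test-set loss and accuracy.
test_loss, test_acc = model.evaluate(test_dataset)
print("Test accuracy: {:.3f}".format(test_acc))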
We can use matplotlib to visualize our trained model performance.
data = test_dataset.take(1)

points, labels = list(data)[0]
points = points[:8, ...]
labels = labels[:8, ...]

# run test data through model
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)

points = points.numpy()

# plot points with predicted class and label
fig = plt.figure(figsize=(15, 10))
for i in range(8):
    ax = fig.add_subplot(2, 4, i + 1, projection="3d")
    ax.scatter(points[i, :, 0], points[i, :, 1], points[i, :, 2])
    ax.set_title(
        "pred: {:}, label: {:}".format(
            CLASS_MAP[preds[i].numpy()], CLASS_MAP[labels.numpy()[i]]
        )
    )
    ax.set_axis_off()

plt.show()
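To classify a single new shape, the same preprocessing applies: load the mesh, sample NUM_POINTS points, add a batch dimension and call model.predict. A minimal sketch (not part of the original example), reusing the chair mesh loaded earlier as a stand-in for an unseen file:
# Optional end-to-end sketch: classify a single mesh file.
single_mesh = trimesh.load(os.path.join(DATA_DIR, "chair/train/chair_0001.off"))
single_points = single_mesh.sample(NUM_POINTS)[None, ...]  # shape (1, 2048, 3)
single_pred = model.predict(single_points)
print("predicted class:", CLASS_MAP[int(np.argmax(single_pred, axis=-1)[0])])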