Code examples / Timeseries / Electroencephalogram Signal Classification for Brain-Computer Interface

Electroencephalogram Signal Classification for Brain-Computer Interface

Author: Okba Bekhelifi
Date created: 2025/01/08
Last modified: 2025/01/08
Description: A Transformer based classification for EEG signal for BCI.

ⓘ This example uses Keras 2

View in Colab GitHub source

Introduction

This tutorial will explain how to build a Transformer based Neural Network to classify Brain-Computer Interface (BCI) Electroencephalograpy (EEG) data recorded in a Steady-State Visual Evoked Potentials (SSVEPs) experiment for the application of a brain-controlled speller.

The tutorial reproduces an experiment from the SSVEPFormer study [1] ( arXiv preprint / Peer-Reviewed paper ). This model was the first Transformer based model to be introduced for SSVEP data classification, we will test it on the Nakanishi et al. [2] public dataset as dataset 1 from the paper.

The process follows an inter-subject classification experiment. Given N subject data in the dataset, the training data partition contains data from N-1 subject and the remaining single subject data is used for testing. the training set does not contain any sample from the testing subject. This way we construct a true subject-independent model. We keep the same parameters and settings as the original paper in all processing operations from preprocessing to training.

The tutorial begins with a quick BCI and dataset description then, we go through the technicalities following these sections: - Setup, and imports. - Dataset download and extraction. - Data preprocessing: EEG data filtering, segmentation and visualization of raw and filtered data, and frequency response for a well performing participant. - Layers and model creation. - Evaluation: a single participant data classification as an example then the total participants data classification. - Visulization: we show the results of training and inference times comparison among the Keras 3 available backends (JAX, Tensorflow, and PyTorch) on three different GPUs. - Conclusion: final discussion and remarks.

Dataset description


BCI and SSVEP:

A BCI offers the ability to communicate using only brain activity, this can be achieved through exogenous stimuli that generate specific responses indicating the intent of the subject. the responses are elicited when the user focuses their attention on the target stimulus. We can use visual stimuli by presenting the subject with a set of options typically on a monitor as a grid to select one command at a time. Each stimulus will flicker following a fixed frequency and phase, the resulting EEG recorded at occipital and occipito-parietal areas of the cortex (visual cortex) will have higher power in the associated frequency with the stimulus where the subject was looking at. This type of BCI paradigm is called the Steady-State Visual Evoked Potentials (SSVEPs) and became widely used for multiple application due to its reliability and high perfromance in classification and rapidity as a 1-second of EEG is sufficient making a command. Other types of brain responses exists and do not require external stimulations, however they are less reliable. Demo video

This tutorials uses the 12 commands (class) public SSVEP dataset [2] with the following interface emulating a phone dialing numbers. dataset

The dataset was recorded with 10 participants, each faced the above 12 SSVEP stimuli (A). The stimulation frequencies ranged from 9.25Hz to 14.75 Hz with 0.5Hz step, and phases ranged from 0 to 1.5 π with 0.5 π step for each row.(B). The EEG signal was acquired with 8 electrodes (channels) (PO7, PO3, POz, PO4, PO8, O1, Oz, O2) sampling frequency was 2048 Hz then the stored data were downsampled to 256 Hz. The subjects completed 15 blocks of recordings, each consisted of 12 random ordered stimulations (1 for each class) of 4 seconds each. In total, each subject conducted 180 trials.

Setup


Select JAX backend

import os

os.environ["KERAS_BACKEND"] = "jax"

Install dependencies

!pip install -q numpy
!pip install -q scipy
!pip install -q matplotlib

Imports

# deep learning libraries
from keras import backend as K
from keras import layers
import keras

# visualization and signal processing imports
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from scipy.signal import butter, filtfilt
from scipy.io import loadmat

# setting the backend, seed and Keras channel format
K.set_image_data_format("channels_first")
keras.utils.set_random_seed(42)

Download and extract dataset


Nakanishi et. al 2015 DataSet Repo

!curl -O https://sccn.ucsd.edu/download/cca_ssvep.zip
!unzip cca_ssvep.zip
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

0 0 0 0 0 0 0 0 –:–:– –:–:– –:–:– 0 0 0 0 0 0 0 0 0 –:–:– –:–:– –:–:– 0

0 145M 0 49152 0 0 40897 0 1:02:11 0:00:01 1:02:10 40891

1 145M 1 2480k 0 0 1140k 0 0:02:10 0:00:02 0:02:08 1140k

9 145M 9 13.4M 0 0 4371k 0 0:00:34 0:00:03 0:00:31 4371k

17 145M 17 25.9M 0 0 6201k 0 0:00:24 0:00:04 0:00:20 6200k

25 145M 25 37.2M 0 0 7398k 0 0:00:20 0:00:05 0:00:15 7632k

33 145M 33 48.9M 0 0 8052k 0 0:00:18 0:00:06 0:00:12 9972k

41 145M 41 60.3M 0 0 8586k 0 0:00:17 0:00:07 0:00:10 11.5M

49 145M 49 71.9M 0 0 9014k 0 0:00:16 0:00:08 0:00:08 11.6M

57 145M 57 84.2M 0 0 9423k 0 0:00:15 0:00:09 0:00:06 11.9M

66 145M 66 96.5M 0 0 9709k 0 0:00:15 0:00:10 0:00:05 11.8M

74 145M 74 108M 0 0 9938k 0 0:00:14 0:00:11 0:00:03 12.0M

82 145M 82 119M 0 0 9.8M 0 0:00:14 0:00:12 0:00:02 11.8M

90 145M 90 131M 0 0 9.9M 0 0:00:14 0:00:13 0:00:01 11.8M

98 145M 98 142M 0 0 10.0M 0 0:00:14 0:00:14 –:–:– 11.7M 100 145M 100 145M 0 0 10.1M 0 0:00:14 0:00:14 –:–:– 11.7M

Archive:  cca_ssvep.zip
   creating: cca_ssvep/
  inflating: cca_ssvep/s4.mat        
  inflating: cca_ssvep/s5.mat        
  inflating: cca_ssvep/s3.mat        
  inflating: cca_ssvep/s7.mat        
  inflating: cca_ssvep/chan_locs.pdf  
  inflating: cca_ssvep/readme.txt    
  inflating: cca_ssvep/s2.mat        
  inflating: cca_ssvep/s8.mat        
  inflating: cca_ssvep/s10.mat       
  inflating: cca_ssvep/s9.mat        
  inflating: cca_ssvep/s6.mat        
  inflating: cca_ssvep/s1.mat        

Pre-Processing

The preprocessing steps followed are first to read the EEG data for each subject, then to filter the raw data in a frequency interval where most useful information lies, then we select a fixed duration of signal starting from the onset of the stimulation (due to latency delay caused by the visual system we start we add 135 milliseconds to the stimulation onset). Lastly, all subjects data are concatenated in a single Tensor of the shape: [subjects x samples x channels x trials]. The data labels are also concatenated following the order of the trials in the experiments and will be a matrix of the shape [subjects x trials] (here by channels we mean electrodes, we use this notation throughout the tutorial).

def raw_signal(folder, fs=256, duration=1.0, onset=0.135):
    """selecting a 1-second segment of the raw EEG signal for
    subject 1.
    """
    onset = 38 + int(onset * fs)
    end = int(duration * fs)
    data = loadmat(f"{folder}/s1.mat")
    # samples, channels, trials, targets
    eeg = data["eeg"].transpose((2, 1, 3, 0))
    # segment data
    eeg = eeg[onset : onset + end, :, :, :]
    return eeg


def segment_eeg(
    folder, elecs=None, fs=256, duration=1.0, band=[5.0, 45.0], order=4, onset=0.135
):
    """Filtering and segmenting EEG signals for all subjects."""
    n_subejects = 10
    onset = 38 + int(onset * fs)
    end = int(duration * fs)
    X, Y = [], []  # empty data and labels

    for subj in range(1, n_subejects + 1):
        data = loadmat(f"{data_folder}/s{subj}.mat")
        # samples, channels, trials, targets
        eeg = data["eeg"].transpose((2, 1, 3, 0))
        # filter data
        eeg = filter_eeg(eeg, fs=fs, band=band, order=order)
        # segment data
        eeg = eeg[onset : onset + end, :, :, :]
        # reshape labels
        samples, channels, blocks, targets = eeg.shape
        y = np.tile(np.arange(1, targets + 1), (blocks, 1))
        y = y.reshape((1, blocks * targets), order="F")

        X.append(eeg.reshape((samples, channels, blocks * targets), order="F"))
        Y.append(y)

    X = np.array(X, dtype=np.float32, order="F")
    Y = np.array(Y, dtype=np.float32).squeeze()

    return X, Y


def filter_eeg(data, fs=256, band=[5.0, 45.0], order=4):
    """Filter EEG signal using a zero-phase IIR filter"""
    B, A = butter(order, np.array(band) / (fs / 2), btype="bandpass")
    return filtfilt(B, A, data, axis=0)

Segment data into epochs

data_folder = os.path.abspath("./cca_ssvep")
band = [8, 64]  # low-frequency / high-frequency cutoffS
order = 4  # filter order
fs = 256  # sampling frequency
duration = 1.0  # 1 second

# raw signal
X_raw = raw_signal(data_folder, fs=fs, duration=duration)
print(
    f"A single subject raw EEG (X_raw) shape: {X_raw.shape} [Samples x Channels x Blocks x Targets]"
)

# segmented signal
X, Y = segment_eeg(data_folder, band=band, order=order, fs=fs, duration=duration)
print(
    f"Full training data (X) shape: {X.shape} [Subject x Samples x Channels x Trials]"
)
print(f"data labels (Y) shape:        {Y.shape} [Subject x Trials]")

samples = X.shape[1]
time = np.linspace(0.0, samples / fs, samples) * 1000
A single subject raw EEG (X_raw) shape: (256, 8, 15, 12) [Samples x Channels x Blocks x Targets]

Full training data (X) shape: (10, 256, 8, 180) [Subject x Samples x Channels x Trials]
data labels (Y) shape:        (10, 180) [Subject x Trials]

Visualize EEG signal


EEG in time

Raw EEG vs Filtered EEG The same 1-second recording for subject s1 at Oz (central electrode in the visual cortex, back of the head) is illustrated. left is the raw EEG as recorded and in the right is the filtered EEG on the [8, 64] Hz frequency band. we see less noise and normalized amplitude values in a natural EEG range.

elec = 6  # Oz channel

x_label = "Time (ms)"
y_label = "Voltage (uV)"
# Create subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Plot data on the first subplot
ax1.plot(time, X_raw[:, elec, 0, 0], "r-")
ax1.set_xlabel(x_label)
ax1.set_ylabel(y_label)
ax1.set_title("Raw EEG : 1 second at Oz ")

# Plot data on the second subplot
ax2.plot(time, X[0, :, elec, 0], "b-")
ax2.set_xlabel(x_label)
ax2.set_ylabel(y_label)
ax2.set_title("Filtered EEG between 8-64 Hz: 1 second at Oz")

# Adjust spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()

png


EEG frequency representation

Using the welch method, we visualize the frequency power for a well performing subject for the entire 4 seconds EEG recording at Oz electrode for each stimuli. the red peaks indicate the stimuli fundamental frequency and the 2nd harmonics (double the fundamental frequency). we see clear peaks showing the high responses from that subject which means that this subject is a good candidate for SSVEP BCI control. In many cases the peaks are weak or absent, meaning that subject do not achieve the task correctly.

eeg_frequency

Create Layers and model

Create Layers in a cross-framework custom component fashion. In the SSVEPFormer, the data is first transformed to the frequency domain through Fast-Fourier transform (FFT), to construct a complex spectrum presentation consisting of the concatenation of frequency and phase information in a fixed frequency band. To keep the model in an end-to-end format, we implement the complex spectrum transformation as non-trainable layer.

model The SSVEPFormer unlike the Transformer architecture does not contain positional encoding/embedding layers which replaced a channel combination block that has a layer of Conv1D layer of 1 kernel size with double input channels (double the count of electrodes) number of filters, and LayerNorm, Gelu activation and dropout. Another difference with Transformers is the absence of multi-head attention layers with attention mechanism. The model encoder contains two identical and successive blocks. Each block has two sub-blocks of CNN module and MLP module. the CNN module consists of a LayerNorm, Conv1D with the same number of filters as channel combination, LayerNorm, Gelu, Dropout and an residual connection. The MLP module consists of a LayerNorm, Dense layer, Gelu, droput and residual connection. the Dense layer is applied on each channel separately. The last block of the model is MLP head with Flatten layer, Dropout, Dense, LayerNorm, Gelu, Dropout and Dense layer with softmax acitvation. All trainable weights are initialized by a normal distribution with 0 mean and 0.01 standard deviation as state in the original paper.

class ComplexSpectrum(keras.layers.Layer):
    def __init__(self, nfft=512, fft_start=8, fft_end=64):
        super().__init__()
        self.nfft = nfft
        self.fft_start = fft_start
        self.fft_end = fft_end

    def call(self, x):
        samples = x.shape[-1]
        x = keras.ops.rfft(x, fft_length=self.nfft)
        real = x[0] / samples
        imag = x[1] / samples
        real = real[:, :, self.fft_start : self.fft_end]
        imag = imag[:, :, self.fft_start : self.fft_end]
        x = keras.ops.concatenate((real, imag), axis=-1)
        return x


class ChannelComb(keras.layers.Layer):
    def __init__(self, n_channels, drop_rate=0.5):
        super().__init__()
        self.conv = layers.Conv1D(
            2 * n_channels,
            1,
            padding="same",
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.normalization = layers.LayerNormalization()
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)

    def call(self, x):
        x = self.conv(x)
        x = self.normalization(x)
        x = self.activation(x)
        x = self.drop(x)
        return x


class ConvAttention(keras.layers.Layer):
    def __init__(self, n_channels, drop_rate=0.5):
        super().__init__()
        self.norm = layers.LayerNormalization()
        self.conv = layers.Conv1D(
            2 * n_channels,
            31,
            padding="same",
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)

    def call(self, x):
        input = x
        x = self.norm(x)
        x = self.conv(x)
        x = self.activation(x)
        x = self.drop(x)
        x = x + input
        return x


class ChannelMLP(keras.layers.Layer):
    def __init__(self, n_features, drop_rate=0.5):
        super().__init__()
        self.norm = layers.LayerNormalization()
        self.mlp = layers.Dense(
            2 * n_features,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)
        self.cat = layers.Concatenate(axis=1)

    def call(self, x):
        input = x
        channels = x.shape[1]  # x shape : NCF
        x = self.norm(x)
        output_channels = []
        for i in range(channels):
            c = self.mlp(x[:, :, i])
            c = layers.Reshape([1, -1])(c)
            output_channels.append(c)
        x = self.cat(output_channels)
        x = self.activation(x)
        x = self.drop(x)
        x = x + input
        return x


class Encoder(keras.layers.Layer):
    def __init__(self, n_channels, n_features, drop_rate=0.5):
        super().__init__()
        self.attention1 = ConvAttention(n_channels, drop_rate=drop_rate)
        self.mlp1 = ChannelMLP(n_features, drop_rate=drop_rate)
        self.attention2 = ConvAttention(n_channels, drop_rate=drop_rate)
        self.mlp2 = ChannelMLP(n_features, drop_rate=drop_rate)

    def call(self, x):
        x = self.attention1(x)
        x = self.mlp1(x)
        x = self.attention2(x)
        x = self.mlp2(x)
        return x


class MlpHead(keras.layers.Layer):
    def __init__(self, n_classes, drop_rate=0.5):
        super().__init__()
        self.flatten = layers.Flatten()
        self.drop = layers.Dropout(drop_rate)
        self.linear1 = layers.Dense(
            6 * n_classes,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.norm = layers.LayerNormalization()
        self.activation = layers.Activation(activation="gelu")
        self.drop2 = layers.Dropout(drop_rate)
        self.linear2 = layers.Dense(
            n_classes,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )

    def call(self, x):
        x = self.flatten(x)
        x = self.drop(x)
        x = self.linear1(x)
        x = self.norm(x)
        x = self.activation(x)
        x = self.drop2(x)
        x = self.linear2(x)
        return x

Create a sequential model with the layers above

def create_ssvepformer(
    input_shape, fs, resolution, fq_band, n_channels, n_classes, drop_rate
):
    nfft = round(fs / resolution)
    fft_start = int(fq_band[0] / resolution)
    fft_end = int(fq_band[1] / resolution) + 1
    n_features = fft_end - fft_start

    model = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            ComplexSpectrum(nfft, fft_start, fft_end),
            ChannelComb(n_channels=n_channels, drop_rate=drop_rate),
            Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
            Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
            MlpHead(n_classes=n_classes, drop_rate=drop_rate),
            layers.Activation(activation="softmax"),
        ]
    )

    return model

Evaluation

# Training settings same as the original paper
BATCH_SIZE = 128
EPOCHS = 100
LR = 0.001  # learning rate
WD = 0.001  # weight decay
MOMENTUM = 0.9
DROP_RATE = 0.5

resolution = 0.25

From the entire dataset we select folds for each subject evaluation. construct a tf dataset object for train and testing data and create the model and launch the training using SGD optimizer.

def concatenate_subjects(x, y, fold):
    X = np.concatenate([x[idx] for idx in fold], axis=-1)
    Y = np.concatenate([y[idx] for idx in fold], axis=-1)
    X = X.transpose((2, 1, 0))  # trials x channels x samples
    return X, Y - 1  # transform labels to values from 0...11


def evaluate_subject(
    x_train,
    y_train,
    x_val,
    y_val,
    input_shape,
    fs=256,
    resolution=0.25,
    band=[8, 64],
    channels=8,
    n_classes=12,
    drop_rate=DROP_RATE,
):

    train_dataset = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    test_dataset = (
        tf.data.Dataset.from_tensor_slices((x_val, y_val))
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    model = create_ssvepformer(
        input_shape, fs, resolution, band, channels, n_classes, drop_rate
    )
    sgd = keras.optimizers.SGD(learning_rate=LR, momentum=MOMENTUM, weight_decay=WD)

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=sgd,
        metrics=["accuracy"],
        jit_compile=True,
    )

    history = model.fit(
        train_dataset,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=test_dataset,
        verbose=0,
    )
    loss, acc = model.evaluate(test_dataset)
    return acc * 100

Run evaluation

channels = X.shape[2]
samples = X.shape[1]
input_shape = (channels, samples)
n_classes = 12

model = create_ssvepformer(
    input_shape, fs, resolution, band, channels, n_classes, DROP_RATE
)
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                          Output Shape                         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ complex_spectrum (ComplexSpectrum)   │ (None, 8, 450)              │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ channel_comb (ChannelComb)           │ (None, 16, 450)             │           1,044 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ encoder (Encoder)                    │ (None, 16, 450)             │          34,804 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ encoder_1 (Encoder)                  │ (None, 16, 450)             │          34,804 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ mlp_head (MlpHead)                   │ (None, 12)                  │         519,492 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ activation_10 (Activation)           │ (None, 12)                  │               0 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 590,144 (2.25 MB)
 Trainable params: 590,144 (2.25 MB)
 Non-trainable params: 0 (0.00 B)

Evaluation on all subjects following a leave-one-subject out data repartition scheme

accs = np.zeros(10)

for subject in range(10):
    print(f"Testing subject: {subject+ 1}")

    # create train / test folds
    folds = np.delete(np.arange(10), subject)
    train_index = folds
    test_index = [subject]

    # create data split for each subject
    x_train, y_train = concatenate_subjects(X, Y, train_index)
    x_val, y_val = concatenate_subjects(X, Y, test_index)

    # train and evaluate a fold and compute the time it takes
    acc = evaluate_subject(x_train, y_train, x_val, y_val, input_shape)

    accs[subject] = acc

print(f"\nAccuracy Across Subjects: {accs.mean()} % std: {np.std(accs)}")
Testing subject: 1

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1737801392.665434    1425 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1737801393.156241    1425 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1737801393.156492    1425 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1737801393.160002    1425 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1737801393.160279    1425 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1737801393.160421    1425 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1737801393.168480    1425 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1737801393.168667    1425 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1737801393.168908    1425 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355

1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.5938 - loss: 1.2483


2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 39ms/step - accuracy: 0.5868 - loss: 1.3010

Testing subject: 2

1/2 ━━━━━━━━━━[37m━━━━━━━━━━ 0s 7ms/step - accuracy: 0.5469 - loss: 1.5270


2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.5119 - loss: 1.5974

Testing subject: 3

1/2 ━━━━━━━━━━[37m━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7266 - loss: 0.8867


2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.6903 - loss: 1.0117

Testing subject: 4

1/2 ━━━━━━━━━━[37m━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9688 - loss: 0.1574


2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9637 - loss: 0.1487

Testing subject: 5

1/2 ━━━━━━━━━━[37m━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9375 - loss: 0.2184


2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9458 - loss: 0.1887

Testing subject: 6

1/2 ━━━━━━━━━━[37m━━━━━━━━━━ 0s 12ms/step - accuracy: 0.9688 - loss: 0.1117


2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9674 - loss: 0.1018

Testing subject: 7

1/2 ━━━━━━━━━━[37m━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9141 - loss: 0.2639


2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9158 - loss: 0.2592

Testing subject: 8

1/2 ━━━━━━━━━━[37m━━━━━━━━━━ 0s 9ms/step - accuracy: 0.9922 - loss: 0.0562


2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - accuracy: 0.9937 - loss: 0.0518

Testing subject: 9

1/2 ━━━━━━━━━━[37m━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9844 - loss: 0.0669


2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9837 - loss: 0.0701

Testing subject: 10

1/2 ━━━━━━━━━━[37m━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9219 - loss: 0.3438


2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - accuracy: 0.8999 - loss: 0.4543

Accuracy Across Subjects: 84.11111384630203 % std: 17.575586372993953

and that's it! we see how some subjects with no data on the training set still can achieve almost a 100% correct commands and others show poor performance around 50%. In the original paper using PyTorch the average accuracy was 84.04% with 17.37 std. we reached the same values knowing the stochastic nature of deep learning.

Visualizations

Training and inference times comparison between the different backends (Jax, Tensorflow and PyTorch) on the three GPUs available with Colab Free/Pro/Pro+: T4, L4, A100.


Training Time

training_time

Inference Time

inference_time

the Jax backend was the best on training and inference in all the GPUs, the PyTorch was exremely slow due to the jit compilation option being disable because of the complex data type calculated by FFT which is not supported by the PyTorch jit compiler.

Acknowledgment

I thank Chris Perry X @GoogleColab for supporting this work with GPU compute.

References

[1] Chen, J. et al. (2023) ‘A transformer-based deep neural network model for SSVEP classification’, Neural Networks, 164, pp. 521–534. Available at: https://doi.org/10.1016/j.neunet.2023.04.045.

[2] Nakanishi, M. et al. (2015) ‘A Comparison Study of Canonical Correlation Analysis Based Methods for Detecting Steady-State Visual Evoked Potentials’, Plos One, 10(10), p. e0140703. Available at: https://doi.org/10.1371/journal.pone.0140703