Block 2. Data Augmentation and Transfer Learning


Goal. Two strategies for the few-labels problem, both from Chapter 19 of Murphy:

  • Augmentation (19.1): invent new trials by perturbing the ones you have.
  • Transfer learning (19.2): train on other people, then nudge the model toward the target person.

We watch each one work, and then watch them stack.

Time. About 75 minutes. The pretraining cell (section 5a) takes 2 to 3 minutes on a Colab T4 GPU. Start it and keep reading.


0. Setup

%%capture
!pip install -q moabb==1.1.0 mne==1.7.1 braindecode==0.8.1 skorch==1.0.0
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mne
import torch

from moabb.datasets import PhysionetMI
from moabb.paradigms import LeftRightImagery

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_val_score

from scipy.signal import welch

from braindecode.models import EEGNetv4
from braindecode import EEGClassifier
from skorch.callbacks import LRScheduler

mne.set_log_level("WARNING")
np.random.seed(42)
torch.manual_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Setup complete. Using device: {device}")

1. Reload the data

If you ran block 1 in this same Colab session, the dataset is cached and the next cell is fast.

dataset = PhysionetMI()
paradigm = LeftRightImagery()
subjects = list(range(1, 11))
X, y, metadata = paradigm.get_data(dataset=dataset, subjects=subjects)

# braindecode needs integer labels
le = LabelEncoder()
y_int = le.fit_transform(y)

print(f"X shape: {X.shape}  (n_trials, n_channels, n_times)")
print(f"Classes: {le.classes_}")

# Cast to float32 once, for PyTorch
X = X.astype(np.float32)
sfreq = 160  # PhysionetMI sampling rate (Hz)
n_chans, n_times = X.shape[1], X.shape[2]

We will also need bandpower features for the first experiment.

def bandpower_features(X, sfreq=160, bands=(("alpha", 8, 13), ("beta", 13, 30))):
    n_trials, n_channels, n_t = X.shape
    n_bands = len(bands)
    feats = np.zeros((n_trials, n_channels * n_bands), dtype=np.float32)
    for t in range(n_trials):
        for c in range(n_channels):
            f, psd = welch(X[t, c], fs=sfreq, nperseg=min(256, n_t))
            for b, (_, lo, hi) in enumerate(bands):
                m = (f >= lo) & (f <= hi)
                feats[t, c * n_bands + b] = np.log(psd[m].mean() + 1e-12)
    return feats

features = bandpower_features(X, sfreq=sfreq)
print(f"Bandpower features: {features.shape}")
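The double loop in bandpower_features calls welch once per trial and channel. scipy.signal.welch also accepts an n-dimensional array plus an axis argument, so the same features can be computed in one call. Below is a minimal sketch of that loop-free version (bandpower_features_vectorized is a name introduced here for illustration), checked against a copy of the loop version on random data; it assumes welch's defaults (Hann window, 50% overlap), as the original does.

```python
import numpy as np
from scipy.signal import welch

BANDS = (("alpha", 8, 13), ("beta", 13, 30))

def bandpower_features_vectorized(X, sfreq=160, bands=BANDS):
    """Log band power, flattened to (n_trials, n_channels * n_bands)."""
    n_trials, n_channels, n_t = X.shape
    # One call: welch runs along the last axis for every trial and channel at once
    f, psd = welch(X, fs=sfreq, nperseg=min(256, n_t), axis=-1)
    per_band = []
    for _, lo, hi in bands:
        m = (f >= lo) & (f <= hi)
        per_band.append(np.log(psd[..., m].mean(axis=-1) + 1e-12))
    # Bands go last, so the flat index is channel * n_bands + band, as in the loop version
    feats = np.stack(per_band, axis=-1)
    return feats.reshape(n_trials, n_channels * len(bands)).astype(np.float32)


def bandpower_features_loop(X, sfreq=160, bands=BANDS):
    """Reference copy of the notebook's loop version."""
    n_trials, n_channels, n_t = X.shape
    feats = np.zeros((n_trials, n_channels * len(bands)), dtype=np.float32)
    for t in range(n_trials):
        for c in range(n_channels):
            f, psd = welch(X[t, c], fs=sfreq, nperseg=min(256, n_t))
            for b, (_, lo, hi) in enumerate(bands):
                m = (f >= lo) & (f <= hi)
                feats[t, c * len(bands) + b] = np.log(psd[m].mean() + 1e-12)
    return feats


rng = np.random.RandomState(0)
X_demo = rng.randn(6, 4, 480)  # 6 trials, 4 channels, 3 s at 160 Hz
fast = bandpower_features_vectorized(X_demo)
slow = bandpower_features_loop(X_demo)
print(fast.shape, np.allclose(fast, slow, atol=1e-5))
```

On the full dataset this replaces thousands of small welch calls with a single one, which mostly matters in section 3, where features are recomputed for every seed.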

2. Data augmentation: four EEG-specific tricks

Augmentation is a way of telling the model “these transformations should not change the label”. You apply them randomly during training, and the model learns to be invariant.

For EEG, four cheap and standard augmentations:

def time_shift(x, max_shift=20, rng=None):
    """Circularly shift the trial along the time axis by up to max_shift samples.

    np.roll wraps around: samples pushed past one end reappear at the other.
    """
    rng = rng or np.random
    shift = rng.randint(-max_shift, max_shift + 1)
    return np.roll(x, shift, axis=-1)

def channel_dropout(x, p=0.1, rng=None):
    """Zero out each channel independently with probability p."""
    rng = rng or np.random
    keep = rng.rand(x.shape[0]) > p
    out = x.copy()
    out[~keep] = 0
    return out

def add_gaussian_noise(x, sigma=0.1, rng=None):
    """Add white noise scaled by sigma * trial std."""
    rng = rng or np.random
    return x + rng.randn(*x.shape).astype(x.dtype) * x.std() * sigma

def mixup_pair(x1, x2, alpha=0.2, rng=None):
    """Linear blend of two trials of the same class."""
    rng = rng or np.random
    lam = rng.beta(alpha, alpha) if alpha > 0 else 1.0
    return (lam * x1 + (1 - lam) * x2).astype(x1.dtype)
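Because time_shift uses np.roll, the end of the epoch is glued onto its start, which leaves a seam in the shifted trial. If that seam bothers you, a non-wrapping variant fills the gap with zeros instead. This is a sketch of an alternative that the rest of the notebook does not use; shift_with_zeros and time_shift_pad are hypothetical names.

```python
import numpy as np

def shift_with_zeros(x, shift):
    """Shift along the last axis by `shift` samples (positive = later), zero-filling the gap.

    Assumes |shift| < n_times.
    """
    out = np.zeros_like(x)
    if shift > 0:
        out[..., shift:] = x[..., :-shift]   # delayed: zeros at the start
    elif shift < 0:
        out[..., :shift] = x[..., -shift:]   # advanced: zeros at the end
    else:
        out[...] = x
    return out

def time_shift_pad(x, max_shift=20, rng=None):
    """Random non-wrapping shift of up to max_shift samples in either direction."""
    rng = rng or np.random
    return shift_with_zeros(x, rng.randint(-max_shift, max_shift + 1))

x = np.arange(1, 6, dtype=np.float32)[None, :]   # one channel: [1, 2, 3, 4, 5]
print(np.roll(x, 2, axis=-1))    # [[4. 5. 1. 2. 3.]]  the end wraps to the start
print(shift_with_zeros(x, 2))    # [[0. 0. 1. 2. 3.]]
print(shift_with_zeros(x, -2))   # [[3. 4. 5. 0. 0.]]
```

With max_shift=20 on a 3-second epoch, either version changes less than 5% of the samples, so the choice rarely matters much in practice.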

Why these four?

  • Time shift says: the exact onset of the imagery is not perfectly marked. A trial shifted by up to 20 samples (125 ms at 160 Hz) is still the same trial.
  • Channel dropout says: if one electrode falls off, as happens with real BCI hardware, the model should still work.
  • Gaussian noise says: small trial-to-trial differences are not meaningful. The model should not memorize them.
  • Mixup says: a half-and-half blend of two left-hand trials is also a left-hand trial. This one is more aggressive and more controversial.
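mixup_pair only blends trials of the same class, so the label never changes. The usual mixup recipe is more aggressive: it pairs trials regardless of class and blends the one-hot labels with the same weight, which then requires a loss that accepts soft targets. A minimal numpy sketch of that batch version, assuming one-hot labels (mixup_batch is a hypothetical helper that nothing else in this notebook uses):

```python
import numpy as np

def mixup_batch(X, Y_onehot, alpha=0.2, rng=None):
    """Blend each trial with a random partner of any class; blend labels with the same weight."""
    rng = rng or np.random
    lam = rng.beta(alpha, alpha) if alpha > 0 else 1.0
    perm = rng.permutation(len(X))           # partner index for each trial
    X_mix = lam * X + (1 - lam) * X[perm]
    Y_mix = lam * Y_onehot + (1 - lam) * Y_onehot[perm]
    return X_mix.astype(X.dtype), Y_mix, lam

rng = np.random.RandomState(0)
X = rng.randn(8, 3, 50).astype(np.float32)   # 8 trials, 3 channels, 50 samples
y = np.array([0, 1] * 4)
Y = np.eye(2)[y]                              # one-hot labels, shape (8, 2)
X_mix, Y_mix, lam = mixup_batch(X, Y, alpha=0.4, rng=rng)
print(round(float(lam), 3))
print(Y_mix[:3])   # soft labels; each row still sums to 1
```

The soft labels are what make this version "controversial" for EEG: a 70/30 blend of a left-hand and a right-hand trial is labelled 70% left, whether or not the brain would ever produce such a signal.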

Visualize what they do to one trial.

fig, axes = plt.subplots(2, 2, figsize=(11, 6), sharex=True, sharey=True)
# Channel indices in the standard PhysioNet EEGMMIDB order
# (FC5, FC3, FC1, FCz, FC2, FC4, FC6, C5, C3, C1, Cz, C2, C4, ...): 8 is C3, 12 is C4.
# If in doubt, check the names via paradigm.get_data(..., return_epochs=True).
c3, c4 = 8, 12
trial = X[0, c3]  # first trial (subject 1), channel C3
times = np.arange(len(trial)) / sfreq

axes[0, 0].plot(times, trial, alpha=0.6, label="original")
axes[0, 0].plot(times, time_shift(trial[None, :])[0], alpha=0.8, label="time-shifted")
axes[0, 0].set_title("Time shift")
axes[0, 0].legend()

axes[0, 1].plot(times, trial, alpha=0.6, label="original")
axes[0, 1].plot(times, add_gaussian_noise(trial[None, :], sigma=0.3)[0], alpha=0.8,
                label="+ noise")
axes[0, 1].set_title("Gaussian noise (sigma=0.3 for visibility)")
axes[0, 1].legend()

# For channel dropout we need a 2D trial; show C3 and C4.
# With p=0.5, C3 is dropped on about half the runs; rerun if the two lines coincide.
two_ch = X[0, [c3, c4]]
two_ch_dropped = channel_dropout(two_ch, p=0.5)
axes[1, 0].plot(times, two_ch[0], alpha=0.6, label="C3 original")
axes[1, 0].plot(times, two_ch_dropped[0], alpha=0.8, label="C3 after dropout")
axes[1, 0].set_title("Channel dropout (p=0.5 for visibility)")
axes[1, 0].legend()

# Mixup: mixup_pair expects two trials of the same class, so pick the next one
trial2 = X[np.flatnonzero(y == y[0])[1], c3]
mixed = mixup_pair(trial[None, :], trial2[None, :], alpha=2.0)[0]
axes[1, 1].plot(times, trial, alpha=0.4, label="trial 1")
axes[1, 1].plot(times, trial2, alpha=0.4, label="trial 2")
axes[1, 1].plot(times, mixed, alpha=0.9, label="mixed", color="C3")
axes[1, 1].set_title("Mixup (alpha=2.0 for visibility)")
axes[1, 1].legend()

for ax in axes.flat:
    ax.set_xlabel("Time (s)")
plt.tight_layout()
plt.show()

3. Does augmentation help when labels are scarce?

The right place to test augmentation is on small training sets. The prediction: with very few labels, augmentation should help. With many labels, the gain should disappear (you do not need synthetic data when you have real data).

We compare two pipelines on subject 1, varying the number of training trials per class:

  1. Plain: train on the original few trials.
  2. Augmented: train on the original few trials plus 3 augmented copies.

We use the simple bandpower + logistic regression pipeline for speed. The gain should already be visible here; we would expect it to be larger for a deep model.

def augment_trials(X_in, y_in, n_copies=3, rng=None):
    """Stack original X with n_copies of augmented versions."""
    rng = rng or np.random.RandomState(0)
    X_list = [X_in]
    y_list = [y_in]
    for _ in range(n_copies):
        X_aug = np.empty_like(X_in)
        for i, x in enumerate(X_in):
            x = time_shift(x, rng=rng)
            x = channel_dropout(x, p=0.05, rng=rng)
            x = add_gaussian_noise(x, sigma=0.05, rng=rng)
            X_aug[i] = x
        X_list.append(X_aug)
        y_list.append(y_in)
    return np.concatenate(X_list), np.concatenate(y_list)


def few_label_score(X_pool, y_pool, n_per_class, augment=False, n_seeds=10):
    accs = []
    classes = np.unique(y_pool)
    for seed in range(n_seeds):
        rng = np.random.RandomState(seed)
        chosen = []
        for c in classes:
            idx = np.where(y_pool == c)[0]
            chosen.extend(rng.choice(idx, size=n_per_class, replace=False))
        chosen = np.array(chosen)
        X_train_raw, y_train = X_pool[chosen], y_pool[chosen]
        if augment:
            X_train_aug, y_train_aug = augment_trials(X_train_raw, y_train,
                                                      n_copies=3, rng=rng)
            X_train = bandpower_features(X_train_aug, sfreq=sfreq)
            y_train_use = y_train_aug
        else:
            X_train = bandpower_features(X_train_raw, sfreq=sfreq)
            y_train_use = y_train
        # Evaluate on the rest of subject 1
        rest = np.setdiff1d(np.arange(len(y_pool)), chosen)
        X_test = bandpower_features(X_pool[rest], sfreq=sfreq)
        y_test = y_pool[rest]
        pipe = make_pipeline(StandardScaler(),
                             LogisticRegression(max_iter=1000, C=1.0))
        pipe.fit(X_train, y_train_use)
        accs.append(pipe.score(X_test, y_test))
    return np.mean(accs), np.std(accs)


# Subject 1 raw trials (so we can augment in time domain)
mask_subj1 = (metadata["subject"] == 1).values
X_subj1_raw = X[mask_subj1]
y_subj1 = y[mask_subj1]

n_list = [5, 10, 20]
results_aug = []
for n in n_list:
    plain = few_label_score(X_subj1_raw, y_subj1, n, augment=False)
    augm  = few_label_score(X_subj1_raw, y_subj1, n, augment=True)
    results_aug.append((n, plain, augm))
    print(f"n={n}: plain={plain[0]:.3f}+-{plain[1]:.3f}  "
          f"augmented={augm[0]:.3f}+-{augm[1]:.3f}")
ns = [r[0] for r in results_aug]
plain_m = [r[1][0] for r in results_aug]
plain_s = [r[1][1] for r in results_aug]
aug_m = [r[2][0] for r in results_aug]
aug_s = [r[2][1] for r in results_aug]

fig, ax = plt.subplots(figsize=(7, 4))
ax.errorbar(ns, plain_m, yerr=plain_s, marker="o", label="plain", capsize=4)
ax.errorbar(ns, aug_m,   yerr=aug_s,   marker="s", label="+ augmentation", capsize=4)
ax.axhline(0.5, ls="--", color="k", alpha=0.5, label="chance")
ax.set_xlabel("Trials per class")
ax.set_ylabel("Accuracy on rest of subject 1")
ax.set_title("19.1 in action: augmentation helps when labels are scarce")
ax.set_ylim(0.4, 0.9)
ax.legend()
plt.tight_layout()
plt.show()

You should see the augmented line sit a few points above the plain line, especially at n=5 and n=10. The gap shrinks as labels grow.

That is the canonical signature of a working augmentation. The size of the gap depends on the model: we would expect it to be larger for EEGNet, which has many more parameters to fit than this logistic regression.
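augment_trials draws its three copies once, before training, so the model sees the same perturbed trials on every pass. "Apply them randomly during training" can also mean drawing fresh perturbations every time a trial is used. Below is a sketch of that on-the-fly variant as a plain Python generator; iterate_augmented_batches is a hypothetical helper that reuses the same three perturbations inline, and neither the logistic-regression pipeline nor the skorch classifiers in this notebook consume a generator, so you would need your own training loop to use it.

```python
import numpy as np

def iterate_augmented_batches(X, y, batch_size=16, n_epochs=2, max_shift=20,
                              p_drop=0.05, sigma=0.05, seed=0):
    """Yield (X_batch, y_batch) with new random perturbations every time a trial is drawn."""
    rng = np.random.RandomState(seed)
    n = len(X)
    for _ in range(n_epochs):
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            Xb = X[idx].copy()                      # never modify the original trials
            for i in range(len(Xb)):
                # circular time shift, as in time_shift
                Xb[i] = np.roll(Xb[i], rng.randint(-max_shift, max_shift + 1), axis=-1)
                # channel dropout: zero each channel with probability p_drop
                Xb[i][rng.rand(Xb.shape[1]) < p_drop] = 0
                # Gaussian noise scaled by the trial's standard deviation
                Xb[i] += rng.randn(*Xb[i].shape).astype(Xb.dtype) * Xb[i].std() * sigma
            yield Xb, y[idx]


rng = np.random.RandomState(1)
X_demo = rng.randn(40, 4, 160).astype(np.float32)   # 40 synthetic trials
y_demo = np.array([0, 1] * 20)
first_pass = [Xb for Xb, _ in iterate_augmented_batches(X_demo, y_demo, n_epochs=1)]
print(len(first_pass), first_pass[0].shape)
```

The trade-off: fixed copies are simpler and work with any fit() call, while on-the-fly augmentation gives the model a new variant of each trial at every epoch.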


4. Switching to a deep model: EEGNet

EEGNet (Lawhern et al., 2018) is a small convolutional network designed for EEG. It has only a few thousand parameters, so it trains on a single subject in seconds.

braindecode wraps it in a scikit-learn-compatible interface, so we can use fit and score like we did with logistic regression.

def make_eegnet(n_chans, n_times, lr=0.001, epochs=20, batch=32):
    """Factory for a fresh EEGNet classifier."""
    return EEGClassifier(
        EEGNetv4,
        module__n_chans=n_chans,
        module__n_outputs=2,
        module__n_times=n_times,
        optimizer=torch.optim.AdamW,
        optimizer__lr=lr,
        optimizer__weight_decay=0.01,
        train_split=None,
        batch_size=batch,
        max_epochs=epochs,
        device=device,
        verbose=0,
    )

# Sanity check: from-scratch on subject 1 with all labels
clf = make_eegnet(n_chans, n_times, epochs=20)
clf.fit(X[mask_subj1], y_int[mask_subj1])
print(f"EEGNet within-subject (all labels): "
      f"{clf.score(X[mask_subj1], y_int[mask_subj1]):.3f}")

Expect something in the 0.75 to 0.90 range. Note that this is overfitted (trained and evaluated on the same data), so the number is generous. We will get honest numbers in the next sections.
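To see how misleading a training-set score can be, here is a small sketch on purely random data, chosen (as an assumption for the demo) so that the honest answer is known to be chance. A linear model with far more features than trials fits the training labels almost perfectly, while cross_val_score, which only scores folds the model did not train on, stays near 0.5. Since braindecode's classifier follows the scikit-learn interface, the same cross_val_score call should accept it too, at the cost of retraining the network once per fold.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

rng = np.random.RandomState(0)
X_rand = rng.randn(40, 200)        # 40 "trials", 200 features of pure noise
y_rand = np.array([0, 1] * 20)     # labels unrelated to X_rand

# Weak regularization (large C) so the model is free to memorize
clf = LogisticRegression(C=1e4, max_iter=5000)
train_acc = clf.fit(X_rand, y_rand).score(X_rand, y_rand)
cv_accs = cross_val_score(LogisticRegression(C=1e4, max_iter=5000), X_rand, y_rand, cv=5)

print(f"training-set accuracy: {train_acc:.2f}")
print(f"5-fold CV accuracy:    {cv_accs.mean():.2f} +- {cv_accs.std():.2f}")
```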


5. Three transfer regimes on a held-out target

We hold out subject 1 as the target. The other 9 subjects are the “source pool”.

Three regimes to compare. Where a regime uses target data, it gets N=20 trials per class (40 in total) from S1:

Regime                 Pretrained on         Trained / fine-tuned on
From scratch           nothing               20 trials per class of S1
Zero-shot transfer     9 source subjects     nothing
Fine-tuned transfer    9 source subjects     20 trials per class of S1

target_subject = 1
source_subjects = [s for s in subjects if s != target_subject]

mask_target = (metadata["subject"] == target_subject).values
mask_source = ~mask_target

X_target_all, y_target_all = X[mask_target], y_int[mask_target]
X_source, y_source = X[mask_source], y_int[mask_source]

print(f"Target (S{target_subject}): {len(y_target_all)} trials")
print(f"Source (other 9):           {len(y_source)} trials")
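With a single target subject, the final numbers say as much about S1 as about the method. A natural extension is to let every subject take a turn as the target. Below is a sketch of the split logic using scikit-learn's LeaveOneGroupOut with the subject IDs as groups; subject_of_trial is a stand-in with illustrative sizes (an assumption), where the notebook would pass metadata["subject"].values.

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

# Stand-in for metadata["subject"].values: 10 subjects, 40 trials each (illustrative sizes)
subject_of_trial = np.repeat(np.arange(1, 11), 40)
X_dummy = np.zeros((len(subject_of_trial), 1))   # only the number of rows matters here

logo = LeaveOneGroupOut()
targets = []
for source_idx, target_idx in logo.split(X_dummy, groups=subject_of_trial):
    target = np.unique(subject_of_trial[target_idx])
    # Each split: exactly one target subject, absent from the source pool
    assert len(target) == 1 and target[0] not in subject_of_trial[source_idx]
    targets.append(int(target[0]))

print(logo.get_n_splits(groups=subject_of_trial), targets)
```

Inside the loop you would pretrain on source_idx and fine-tune and test on target_idx, which multiplies the pretraining cost by the number of subjects.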

5a. Pretrain on source

This cell takes about 2-3 minutes on T4. Start it, then keep reading.

clf_pretrained = make_eegnet(n_chans, n_times, epochs=30)
clf_pretrained.fit(X_source, y_source)
print(f"Source training accuracy: {clf_pretrained.score(X_source, y_source):.3f}")

While that runs: a small thought experiment. How much of what EEGNet learns from 9 brains do you expect to transfer to a 10th brain? The answer is “some, but not enough”. Otherwise BCI research would be over.

5b. Pick a small fine-tune set on the target

def sample_target_trials(X_t, y_t, n_per_class, seed=0):
    rng = np.random.RandomState(seed)
    chosen = []
    for c in np.unique(y_t):
        idx = np.where(y_t == c)[0]
        chosen.extend(rng.choice(idx, size=n_per_class, replace=False))
    chosen = np.array(chosen)
    rest = np.setdiff1d(np.arange(len(y_t)), chosen)
    return X_t[chosen], y_t[chosen], X_t[rest], y_t[rest]

n_target = 20
X_t_train, y_t_train, X_t_test, y_t_test = sample_target_trials(
    X_target_all, y_target_all, n_per_class=n_target, seed=0
)
print(f"Target train: {len(y_t_train)} trials")
print(f"Target test:  {len(y_t_test)} trials")

5c. From scratch on target

clf_scratch = make_eegnet(n_chans, n_times, epochs=20)
clf_scratch.fit(X_t_train, y_t_train)
acc_scratch = clf_scratch.score(X_t_test, y_t_test)
print(f"From scratch on target ({2*n_target} trials): {acc_scratch:.3f}")

5d. Zero-shot transfer

acc_zeroshot = clf_pretrained.score(X_t_test, y_t_test)
print(f"Zero-shot transfer (pretrained, no fine-tuning): {acc_zeroshot:.3f}")

5e. Fine-tuned transfer

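We copy the pretrained model and continue training it on the small target set with a lower learning rate. We call partial_fit rather than fit because in skorch, fit re-initializes the network (unless warm_start=True), which would discard the pretrained weights.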

import copy
clf_finetune = copy.deepcopy(clf_pretrained)
clf_finetune.set_params(max_epochs=15, optimizer__lr=0.0003)
clf_finetune.partial_fit(X_t_train, y_t_train)
acc_finetune = clf_finetune.score(X_t_test, y_t_test)
print(f"Fine-tuned transfer: {acc_finetune:.3f}")
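scikit-learn's SGDClassifier draws the same line between fit (start over) and partial_fit (continue from the current weights), so it can serve as a tiny stand-in for the pretrained network. The synthetic source and target data below are assumptions made only to illustrate the mechanics; they are not a model of EEG.

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

rng = np.random.RandomState(0)
# Synthetic "source pool": the label depends on feature 0
X_src = rng.randn(500, 5)
y_src = (X_src[:, 0] > 0).astype(int)
# Synthetic "target": same rule, but all features shifted (a crude stand-in for subject shift)
X_tgt = rng.randn(40, 5) + 1.0
y_tgt = (X_tgt[:, 0] > 1.0).astype(int)

model = SGDClassifier(random_state=0).fit(X_src, y_src)   # "pretraining"
w_pretrained = model.coef_.copy()

# "Fine-tuning": a few passes over the small target set, starting from the pretrained weights
for _ in range(5):
    model.partial_fit(X_tgt, y_tgt)
w_finetuned = model.coef_.copy()

# By contrast, fit() on the target alone starts from scratch and ignores the source pool
fresh = SGDClassifier(random_state=0).fit(X_tgt, y_tgt)

print(np.round(w_pretrained, 2), np.round(w_finetuned, 2), np.round(fresh.coef_, 2), sep="\n")
```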

5f. Fine-tuned transfer + augmentation

X_t_train_aug, y_t_train_aug = augment_trials(
    X_t_train, y_t_train, n_copies=3, rng=np.random.RandomState(0)
)

clf_finetune_aug = copy.deepcopy(clf_pretrained)
clf_finetune_aug.set_params(max_epochs=15, optimizer__lr=0.0003)
clf_finetune_aug.partial_fit(X_t_train_aug.astype(np.float32), y_t_train_aug)
acc_finetune_aug = clf_finetune_aug.score(X_t_test, y_t_test)
print(f"Fine-tuned + augmented: {acc_finetune_aug:.3f}")

6. The picture

labels = ["from scratch\n(20 trials)",
          "zero-shot\ntransfer",
          "fine-tuned\ntransfer",
          "fine-tuned\n+ augmentation"]
values = [acc_scratch, acc_zeroshot, acc_finetune, acc_finetune_aug]

fig, ax = plt.subplots(figsize=(8, 4.5))
bars = ax.bar(labels, values, color=["#888", "#4a90d9", "#2c7a3e", "#a14545"])
ax.axhline(0.5, ls="--", color="k", alpha=0.5, label="chance")
ax.set_ylim(0.4, 1.0)
ax.set_ylabel("Accuracy on held-out trials of S1")
ax.set_title(f"Block 2 picture: target = S{target_subject}, n_train = {2*n_target}")
ax.legend()
for bar, v in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width()/2, v + 0.01, f"{v:.2f}",
            ha="center", fontsize=11)
plt.tight_layout()
plt.show()

The story this plot should tell (keep in mind it comes from one split and one seed, so individual bars are noisy):

  • From scratch is barely above chance: 20 trials per class are not enough for a randomly initialized EEGNet to learn anything reliable.
  • Zero-shot transfer is a little better, but not by much. Naive transfer from other brains does not solve the problem.
  • Fine-tuning bridges the two: it takes the partial knowledge from other brains and adapts it cheaply to the target.
  • Augmentation on top adds another small gain. The two fixes stack.

This is the cumulative argument of Murphy 19.1 + 19.2 in one bar chart.
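Each bar comes from a single split with only len(y_t_test) test trials, so a few trials either way move it by several points. One quick check is how likely an accuracy at least that high would be if the classifier were guessing. The sketch below writes out an exact one-sided binomial test with math.comb; it is a standard statistical test, not something from the chapter, and p_value_above_chance is a hypothetical helper name.

```python
from math import comb

def p_value_above_chance(n_correct, n_total, chance=0.5):
    """One-sided exact binomial test: P(at least n_correct hits | guessing at rate `chance`)."""
    return sum(comb(n_total, k) * chance**k * (1 - chance) ** (n_total - k)
               for k in range(n_correct, n_total + 1))

# 7 of 10 test trials correct (0.70 accuracy) is weak evidence on its own
print(p_value_above_chance(7, 10))   # 176/1024 = 0.171875
```

In the notebook you could call it as p_value_above_chance(round(acc_finetune * len(y_t_test)), len(y_t_test)). The same arithmetic explains why the discussion questions treat small differences between bars with caution.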


7. Discussion

  1. The zero-shot bar might not be much above chance. Why is “pretraining on 9 brains” not enough by itself for motor imagery, when “pretraining on ImageNet” is often enough by itself for many vision tasks?

  2. Standard mixup blends two trials from any classes and gives the result a blended label; our mixup_pair only blends trials of the same class. Does a blended label make biological sense for EEG? When does mixup work and when does it produce nonsense?

  3. If you had to choose one fix to keep: would you keep the augmentation or the transfer learning? Why?


8. What just happened

We took the cross-subject failure from block 1 and partly fixed it.

The fix was not a single deep learning trick. It was the combination of:

  • generating synthetic versions of the few labels we had (19.1), and
  • starting from a model that had already seen other brains (19.2).

In the next block we will go one step further: instead of pretraining a model from scratch on other brains, we will use a frozen pretrained representation and only learn the final classifier on top. We will also let the unlabeled data start pulling weight, which is the topic of Murphy 19.3.