Help from Other Datasets and Tasks
Time. 13:00 - 15:00 (2h).
Goal. Notebook 1 squeezed one subject’s data. Now look outward: other subjects ran the same paradigm, and other datasets carry usable structure. Three standard fixes that borrow from external data:
- Borrowed features: a pretrained representation does most of the work, no fine-tuning.
- Transfer learning (PML 19.2): pretrain on other people, fine-tune on the target.
- Few-shot calibration (PML 19.6): given pretraining, how few target trials suffice?
0. Setup
Same moabb + braindecode + seaborn as notebook 1, plus three new packages:
- pyriemann: Riemannian geometry on EEG covariance matrices. Used in §4 for the frozen-features comparison.
- sentence-transformers (optional, §7 only): Off-the-shelf pretrained text embeddings. Drop it if you skip §7.
- umap-learn (optional, §7 only): Nonlinear dimensionality reduction for visualizing text embeddings. Drop it if you skip §7.
%%capture
!pip install -q moabb==1.1.0 braindecode==0.8.1 pyriemann==0.6 seaborn \
sentence-transformers==3.0.1 umap-learn==0.5.6import os, copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import mne
import torch
from moabb.datasets import PhysionetMI
from moabb.paradigms import LeftRightImagery
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_val_score
from scipy.signal import welch
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from braindecode.models import EEGNetv4
from braindecode import EEGClassifier
mne.set_log_level("WARNING")
np.random.seed(42); torch.manual_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Setup complete. Device: {device}")#@title Helpers (run once, then collapse) { display-mode: "form" }
BANDS = (("alpha", 8, 13), ("beta", 13, 30))
def extract_bandpower(Xin, sfreq=160, bands=BANDS):
    """Compute log band power per channel and frequency band.

    Welch's PSD is taken along the last axis, using segments of at most 256
    samples. For each ``(name, lo, hi)`` in ``bands``, the PSD is averaged over
    frequencies in the closed interval [lo, hi].

    Parameters
    ----------
    Xin : ndarray
        Trials along axis 0 and time along the last axis. Presumably
        (n_trials, n_channels, n_times); TODO confirm for other callers.
    sfreq : float
        Sampling rate in Hz, passed to ``welch``.
    bands : sequence of (name, lo_hz, hi_hz)
        Frequency bands to average over.

    Returns
    -------
    ndarray, float32, shape (n_trials, -1)
        ``log(power + 1e-12)``, flattened per trial with the band index varying
        fastest.
    """
    n_trials = len(Xin)
    seg_len = min(256, Xin.shape[-1])
    freqs, psd = welch(Xin, fs=sfreq, nperseg=seg_len, axis=-1)
    per_band = []
    for _name, f_lo, f_hi in bands:
        in_band = np.logical_and(freqs >= f_lo, freqs <= f_hi)
        per_band.append(psd[..., in_band].mean(axis=-1))
    # The small epsilon keeps log() finite for all-zero (e.g. dropped) channels.
    log_power = np.log(np.stack(per_band, axis=-1) + 1e-12)
    return log_power.reshape(n_trials, -1).astype(np.float32)
def balanced_split(X_t, y_t, n_per_class, seed=0):
    """Split trials into a class-balanced training set and a held-out rest.

    Draws ``n_per_class`` indices per class (in ``np.unique`` order) without
    replacement from a ``RandomState(seed)``. Every index not drawn goes to
    the held-out set, in ascending order.

    Returns
    -------
    (X_train, y_train, X_rest, y_rest)

    Raises
    ------
    ValueError
        From ``RandomState.choice`` when a class has fewer than
        ``n_per_class`` trials.
    """
    rng = np.random.RandomState(seed)
    picks = []
    for label in np.unique(y_t):
        label_idx = np.flatnonzero(y_t == label)
        picks.append(rng.choice(label_idx, size=n_per_class, replace=False))
    train_idx = np.concatenate(picks)
    held_out = np.ones(len(y_t), dtype=bool)
    held_out[train_idx] = False
    rest_idx = np.flatnonzero(held_out)
    return X_t[train_idx], y_t[train_idx], X_t[rest_idx], y_t[rest_idx]
def lr_pipe(C=1.0):
    """Return an unfitted StandardScaler -> LogisticRegression pipeline.

    ``C`` is the inverse regularization strength for LogisticRegression.
    ``max_iter=1000`` gives the solver more iterations than sklearn's default.
    """
    scaler = StandardScaler()
    classifier = LogisticRegression(max_iter=1000, C=C)
    return make_pipeline(scaler, classifier)
def few_label_curve(features, y_arr, n_list=(5, 10, 20, 40), n_seeds=10):
    """Accuracy of ``lr_pipe()`` as a function of the labeled-trials budget.

    For each budget n in ``n_list``, draw ``n_seeds`` splits with
    ``balanced_split`` (n training trials per class, the rest held out). Fit
    ``lr_pipe()`` on the training part and score it on the held-out part.

    A budget is skipped (with a printed note) when it would leave the smallest
    class with no held-out trial. Drawing more samples than a class has,
    without replacement, raises in NumPy. Drawing exactly the class size leaves
    nothing of that class to test on.

    Parameters
    ----------
    features : array
        One row per trial.
    y_arr : 1-D array
        Labels aligned with ``features``.
    n_list : iterable of int
        Labeled trials per class to try.
    n_seeds : int
        Random splits per budget.

    Returns
    -------
    DataFrame
        Columns ``n_per_class``, ``mean``, ``std`` (accuracy over seeds), with
        one row per budget that was run.
    """
    _, class_counts = np.unique(y_arr, return_counts=True)
    smallest_class = class_counts.min() if class_counts.size else 0
    rows = []
    for n in n_list:
        if n >= smallest_class:
            print(f"few_label_curve: skipping n_per_class={n} "
                  f"(smallest class has only {smallest_class} trials)")
            continue
        accs = []
        for seed in range(n_seeds):
            Xtr, ytr, Xte, yte = balanced_split(features, y_arr, n, seed=seed)
            accs.append(lr_pipe().fit(Xtr, ytr).score(Xte, yte))
        rows.append({"n_per_class": n, "mean": np.mean(accs), "std": np.std(accs)})
    # Explicit columns keep the schema even if every budget was skipped.
    return pd.DataFrame(rows, columns=["n_per_class", "mean", "std"])
def build_eegnet(n_chans, n_times, epochs=20, lr=0.001, batch=32, n_outputs=2):
    """Return an unfitted braindecode ``EEGClassifier`` wrapping ``EEGNetv4``.

    Parameters
    ----------
    n_chans, n_times : int
        Per-trial input shape (channels, samples), forwarded to EEGNetv4.
    epochs : int
        ``max_epochs`` used by ``fit`` / ``partial_fit``.
    lr : float
        AdamW learning rate. Weight decay is fixed at 0.01.
    batch : int
        Mini-batch size.
    n_outputs : int, default 2
        Number of classes. The default of 2 matches the left/right-hand task;
        pass a larger value to reuse this builder for multi-class paradigms.

    Notes
    -----
    Uses the module-level ``device``.
    """
    return EEGClassifier(
        EEGNetv4,
        module__n_chans=n_chans,
        module__n_outputs=n_outputs,
        module__n_times=n_times,
        optimizer=torch.optim.AdamW,
        optimizer__lr=lr,
        optimizer__weight_decay=0.01,
        # No internal validation split: every trial passed to fit() is trained on.
        train_split=None,
        batch_size=batch,
        max_epochs=epochs,
        device=device,
        verbose=0,
    )
def time_shift(x, max_shift=20, rng=None):
    """Circularly shift ``x`` along its last axis by a random integer offset.

    The offset is drawn uniformly from [-max_shift, max_shift] (inclusive).
    ``np.roll`` wraps samples that fall off one end back onto the other.
    ``rng`` defaults to NumPy's global random module.
    """
    if not rng:
        rng = np.random
    offset = rng.randint(-max_shift, max_shift + 1)
    return np.roll(x, offset, axis=-1)
def channel_dropout(x, p=0.05, rng=None):
    """Return a copy of ``x`` with random rows along axis 0 set to zero.

    Each index along the first axis is zeroed independently with probability
    ``p``. When called from ``augment_trials``, ``x`` is a single trial, so
    each row is one channel. The input array is not modified.

    Parameters
    ----------
    x : ndarray
    p : float
        Drop probability in [0, 1]. With p == 0 nothing is dropped.
    rng : object with a ``rand`` method, optional
        E.g. ``np.random.RandomState``. Defaults to NumPy's global RNG.
    """
    rng = rng or np.random
    out = x.copy()
    # rand() draws from [0, 1), so `draw < p` is true with probability exactly
    # p and never for p == 0. `<=` would also drop a row on a draw of 0.0.
    out[rng.rand(x.shape[0]) < p] = 0
    return out
def add_gaussian_noise(x, sigma=0.05, rng=None):
    """Add i.i.d. Gaussian noise scaled to ``sigma`` times the std of ``x``.

    The noise is cast to ``x.dtype`` before scaling. ``rng`` defaults to
    NumPy's global random module.
    """
    if not rng:
        rng = np.random
    noise = rng.randn(*x.shape).astype(x.dtype)
    return x + noise * x.std() * sigma
def augment_trials(X_in, y_in, n_copies=3, rng=None):
    """Return the original trials plus ``n_copies`` randomly perturbed copies.

    Each copy perturbs every trial in turn with time_shift, then
    channel_dropout, then add_gaussian_noise, all drawing from the same
    ``rng``. ``rng`` defaults to ``np.random.RandomState(0)``, so calls
    without it are reproducible.

    Returns
    -------
    (X_out, y_out)
        ``X_in`` followed by the copies, concatenated along axis 0, and
        ``y_in`` repeated to match.
    """
    if not rng:
        rng = np.random.RandomState(0)

    def _perturb(trial):
        shifted = time_shift(trial, rng=rng)
        dropped = channel_dropout(shifted, rng=rng)
        return add_gaussian_noise(dropped, rng=rng)

    X_blocks = [X_in]
    X_blocks.extend(np.stack([_perturb(trial) for trial in X_in]) for _ in range(n_copies))
    y_blocks = [y_in] + [y_in for _ in range(n_copies)]
    return np.concatenate(X_blocks), np.concatenate(y_blocks)
sns.set_theme(context="notebook", style="whitegrid")1. Data
dataset = PhysionetMI()
paradigm = LeftRightImagery()
subjects = list(range(1, 11))
X, y, metadata = paradigm.get_data(dataset=dataset, subjects=subjects)
X = X.astype(np.float32)
y_int = LabelEncoder().fit_transform(y)
sfreq = 160
n_chans, n_times = X.shape[1], X.shape[2]
target_subject = 1
target_mask = (metadata["subject"] == target_subject).values
source_mask = ~target_mask
X_target_all, y_target_all = X[target_mask], y_int[target_mask]
X_source, y_source = X[source_mask], y_int[source_mask]
print(f"Target S{target_subject}: {len(y_target_all)} trials | Source (other 9): {len(y_source)} trials")2. Failure: Not the Right Person
Notebook 1’s model failed because labels were scarce. Here it fails for a different reason: train on one subject, test on another, with no calibration. Sweep every ordered (train, test) pair of the 10 subjects (90 pairs), so each subject serves both as a source and as a target.
features_bp = extract_bandpower(X)
cross_rows = []
for s_train in subjects:
pipe = lr_pipe().fit(features_bp[(metadata["subject"] == s_train).values],
y_int[(metadata["subject"] == s_train).values])
for s_test in subjects:
if s_test == s_train:
continue
m_test = (metadata["subject"] == s_test).values
cross_rows.append({"source": s_train, "target": s_test,
"accuracy": pipe.score(features_bp[m_test], y_int[m_test])})
results_cross = pd.DataFrame(cross_rows)
print(f"Mean cross-subject accuracy: {results_cross['accuracy'].mean():.3f} "
f"(min {results_cross['accuracy'].min():.3f}, max {results_cross['accuracy'].max():.3f})")Compare within-subject (everything one subject has) against cross-subject (no calibration).
within = pd.DataFrame([
{"subject": s,
"accuracy": cross_val_score(lr_pipe(),
features_bp[(metadata["subject"] == s).values],
y_int[(metadata["subject"] == s).values],
cv=5).mean(),
"regime": "within (all labels)"}
for s in subjects
])
cross_by_source = (results_cross.groupby("source")["accuracy"].mean()
.rename_axis("subject").reset_index()
.assign(regime="cross (zero calibration)"))
diagnosis = pd.concat([within, cross_by_source], ignore_index=True)
sns.barplot(data=diagnosis, x="subject", y="accuracy", hue="regime")
plt.axhline(0.5, ls="--", c="k", alpha=0.5)
plt.show()Within-subject bars sit between 0.65 and 0.85; cross-subject bars hover near 0.55. The gap is what borrowing from other subjects has to close.
3. Pretrain a Model
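Train a small CNN (EEGNet) on the source pool, meaning every subject except the target, and save the checkpoint. Training takes about 2–3 minutes on a T4 GPU; if the checkpoint file already exists, it is loaded instead. The next notebook reloads this file.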
PRETRAINED_PATH = "/content/eegnet_pretrained.pt"
clf_pretrained = build_eegnet(n_chans, n_times, epochs=30)
if os.path.exists(PRETRAINED_PATH):
print("Loading pretrained EEGNet from disk")
clf_pretrained.initialize()
clf_pretrained.load_params(f_params=PRETRAINED_PATH)
else:
print("No checkpoint found, training (2-3 minutes)")
clf_pretrained.fit(X_source, y_source)
clf_pretrained.save_params(f_params=PRETRAINED_PATH)
print(f"Saved checkpoint to {PRETRAINED_PATH}")
print(f"Source training accuracy: {clf_pretrained.score(X_source, y_source):.3f}")How much of what the model learned from 9 brains transfers to a 10th? Some, but not enough.
4. Fix 1: Borrowed Features
The right representation does most of the work. Run three feature extractors on S1’s trials, train a logistic regression (LR) on top of each, and sweep the label budget (labeled trials per class).
| Features | Era | Cost |
|---|---|---|
| Bandpower | 1990s | Trivial |
| Riemannian tangent space | 2010s | Cheap |
| Pretrained EEGNet penultimate | 2020s | One forward pass |
bp_s1 = extract_bandpower(X_target_all)
ri_s1 = TangentSpace().fit_transform(
Covariances(estimator="oas").fit_transform(X_target_all)
).astype(np.float32)
print(f"Bandpower: {bp_s1.shape} | Riemannian: {ri_s1.shape}")Pull the EEGNet penultimate via a forward hook on the final layer’s input.
backbone = clf_pretrained.module_.eval().to(device)
penult_buffer = []
handle = backbone.final_layer.register_forward_hook(
lambda m, inp, out: penult_buffer.append(inp[0].detach().cpu().numpy()))
with torch.no_grad():
for i in range(0, len(X_target_all), 64):
backbone(torch.from_numpy(X_target_all[i:i+64]).to(device))
handle.remove()
en_s1 = np.concatenate([f.reshape(f.shape[0], -1) for f in penult_buffer]).astype(np.float32)
print(f"EEGNet penultimate: {en_s1.shape}")Few-labels sweep on all three.
curves = pd.concat({
"bandpower (1990s)": few_label_curve(bp_s1, y_target_all),
"Riemannian (2010s)": few_label_curve(ri_s1, y_target_all),
"EEGNet penultimate (borrowed)": few_label_curve(en_s1, y_target_all),
}, names=["features"]).reset_index(level=0)
sns.lineplot(data=curves, x="n_per_class", y="mean", hue="features", marker="o")
plt.axhline(0.5, ls="--", c="k", alpha=0.5)
plt.show()The pretrained penultimate sits above both hand-engineered baselines at every budget. Borrowing pays even when the network never touches the target: the representation alone is the lift.
5. Fix 2: Transfer Learning
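The cleanest decomposition of “what does pretraining buy you?” needs four conditions scored on the same target test set. Take 20 target trials per class for training or fine-tuning, and score every condition on the remaining target trials. The print below shows how many trials that leaves for testing. With a small test set and a single split, treat small differences between conditions with caution.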
n_target = 20
X_t_train, y_t_train, X_t_test, y_t_test = balanced_split(
X_target_all, y_target_all, n_per_class=n_target, seed=0
)
print(f"Target train: {len(y_t_train)} | Target test: {len(y_t_test)}")regimes = {}
# (i) from scratch on the target alone
scratch = build_eegnet(n_chans, n_times, epochs=20)
scratch.fit(X_t_train, y_t_train)
regimes["from scratch (20 trials)"] = scratch.score(X_t_test, y_t_test)
# (ii) zero-shot: pretrained model, no calibration
regimes["zero-shot transfer"] = clf_pretrained.score(X_t_test, y_t_test)
# (iii) fine-tune the pretrained model on the small target set
ft = copy.deepcopy(clf_pretrained)
ft.set_params(max_epochs=15, optimizer__lr=0.0003)
ft.partial_fit(X_t_train, y_t_train)
regimes["fine-tuned transfer"] = ft.score(X_t_test, y_t_test)
# (iv) fine-tune with augmented copies of the target set
X_aug, y_aug = augment_trials(X_t_train, y_t_train, n_copies=3,
rng=np.random.RandomState(0))
ft_aug = copy.deepcopy(clf_pretrained)
ft_aug.set_params(max_epochs=15, optimizer__lr=0.0003)
ft_aug.partial_fit(X_aug.astype(np.float32), y_aug)
regimes["fine-tuned + augmentation"] = ft_aug.score(X_t_test, y_t_test)
regimes_df = pd.Series(regimes, name="accuracy").rename_axis("regime").reset_index()
regimes_dfsns.barplot(data=regimes_df, x="regime", y="accuracy")
plt.axhline(0.5, ls="--", c="k", alpha=0.5)
plt.xticks(rotation=15)
plt.show()From-scratch is barely above chance: 20 trials cannot teach a randomly-initialized EEGNet anything reliable. Zero-shot is a little better, but not by much — naive transfer doesn’t solve the problem on its own. Fine-tuning bridges them. Augmentation on top adds another small gain. They stack.
6. Fix 3: Few-shot Calibration
Every new BCI user needs a calibration session. The question is: how short can it be?
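Sweep N target trials per class over {1, 2, 5, 10, 20, 50}. For each N, fine-tune a fresh copy of the pretrained EEGNet on 5 random draws and score each copy on the remaining target trials. Budgets larger than the target subject’s data are skipped.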
fewshot_rows = []
for n in [1, 2, 5, 10, 20, 50]:
if 2 * n > len(y_target_all):
continue
accs = []
for seed in range(5):
Xtr, ytr, Xte, yte = balanced_split(X_target_all, y_target_all, n, seed=seed)
clf = copy.deepcopy(clf_pretrained).set_params(max_epochs=15, optimizer__lr=0.0003)
accs.append(clf.partial_fit(Xtr, ytr).score(Xte, yte))
fewshot_rows.append({"n_per_class": n, "mean": np.mean(accs), "std": np.std(accs)})
results_fewshot = pd.DataFrame(fewshot_rows)
results_fewshotsns.lineplot(data=results_fewshot, x="n_per_class", y="mean", marker="o")
plt.axhline(0.5, ls="--", c="k", alpha=0.5)
plt.xscale("log")
plt.show()Steep at the bottom, saturating at the top. The shape tells you what “enough calibration” means for your application — a research demo lives at the steep part; a wheelchair controller does not.
Cog-sci framing. Lake et al. (2015) made the same argument with humans on Omniglot: people learn new concepts from one or a few examples because they bring strong priors. Pretraining on 9 other subjects plays the role of that prior, borrowed as learned parameters rather than argued for in the abstract.
7. Text Data (optional)
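Optional. This section applies the first half of the recipe, a frozen pretrained encoder, to text instead of EEG, and visualizes the embeddings with UMAP. A tiny classifier could be trained on those embeddings exactly as in §4. Skip it without losing anything. Requires `sentence-transformers` and `umap-learn`.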
The recipe is not EEG-specific. Same shape on text with nomic-embed-text-v1.5 (~137M params, open-weight, T4-friendly).
# Toy text dataset: 10 first-person symptom statements for each of 4 disorder
# labels (the dict keys). Some statements share themes across labels (fatigue,
# concentration, sleep), so the groups are not separated by vocabulary alone.
symptoms = {
"Major Depressive Disorder": [
"I feel sad or empty most of the day",
"I have lost interest in activities I used to enjoy",
"I am tired all the time even after resting",
"I have difficulty concentrating on simple tasks",
"I feel worthless or guilty without clear reason",
"I have trouble making everyday decisions",
"My appetite has changed noticeably",
"I sleep too much or cannot sleep at all",
"I think about death frequently",
"I move and speak more slowly than I used to",
],
"Generalized Anxiety Disorder": [
"I worry about everything even small things",
"I cannot control my worrying thoughts",
"I feel restless and on edge most days",
"I am tired despite not doing very much",
"I have difficulty concentrating because of worry",
"I am irritable with people around me",
"My muscles feel tense and sore for no reason",
"I cannot fall asleep because my mind is racing",
"I feel a constant sense of impending doom",
"I worry about my health all the time",
],
"ADHD": [
"I am easily distracted by sounds or movement",
"I have trouble finishing what I start",
"I lose important objects regularly",
"I forget appointments and obligations",
"I find it hard to sit still during meetings",
"I interrupt others when they are speaking",
"I make careless mistakes at work or school",
"I avoid tasks that require sustained mental effort",
"I fidget with my hands or feet constantly",
"I have trouble organizing my daily tasks",
],
"Insomnia Disorder": [
"I lie awake for hours before falling asleep",
"I wake up multiple times during the night",
"I wake up too early and cannot return to sleep",
"I feel unrefreshed in the morning despite sleeping",
"I worry about not being able to sleep",
"I am sleepy during the day from lack of sleep",
"My sleep is interrupted by physical discomfort",
"I dread going to bed because of my sleeplessness",
"I rely on substances to fall asleep",
"My total sleep time is much less than I need",
],
}
# Flatten {label: [sentences]} into one (text, label) row per sentence.
records = []
for label_name, sentences in symptoms.items():
    records.extend((sentence, label_name) for sentence in sentences)
text_df = pd.DataFrame(records, columns=["text", "label"])
print(f"{len(text_df)} sentences, {text_df['label'].nunique()} classes")
from sentence_transformers import SentenceTransformer
# Embed each symptom sentence with a frozen, pretrained text encoder.
emb_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
# The nomic-embed-text-v1.5 model card requires a task-instruction prefix on
# every input. "clustering: " is the prefix for grouping texts, which is what
# the UMAP plot below shows.
# normalize_embeddings=True L2-normalizes each vector, so the Euclidean
# distance UMAP uses by default tracks cosine similarity.
emb_inputs = ["clustering: " + text for text in text_df["text"].tolist()]
emb = emb_model.encode(emb_inputs, show_progress_bar=False, normalize_embeddings=True)
print(f"Embedding shape: {emb.shape}")
import umap
emb_2d = umap.UMAP(n_neighbors=8, min_dist=0.3, random_state=42).fit_transform(emb)
plot_df = pd.DataFrame({"UMAP 1": emb_2d[:, 0], "UMAP 2": emb_2d[:, 1],
"label": text_df["label"]})
sns.scatterplot(data=plot_df, x="UMAP 1", y="UMAP 2", hue="label", s=80)
plt.show()Four reasonably separated clusters — the model learned, with no supervision, that these symptom families are distinct. Same lift would be there on Cognitive Atlas task descriptions, abstracts of a journal, anything text-shaped.
8. Takeaways
sns.barplot(data=regimes_df, x="regime", y="accuracy")
plt.axhline(0.5, ls="--", c="k", alpha=0.5)
plt.xticks(rotation=15)
plt.show()Borrowing from other people works, but you have to combine it with a little local data. The pretrained EEGNet sits on disk at /content/eegnet_pretrained.pt; the next notebook reloads it to compare against active learning and meta learning.