run_for_all_spikes: vectorize per-channel PCA transform #4485
Description
Problem
The inner loop in `_all_pc_extractor_chunk` (in `principal_component.py`) calls `pca_model[chan_ind].transform()` once per spike per channel:

```python
for i in range(i0, i1):
    wf = traces[st - start - nbefore : st - start + nafter, :]
    ...
    for c, chan_ind in enumerate(chan_inds):
        w = wf[:, chan_ind]
        ...
        all_pcs[i, :, c] = pca_model[chan_ind].transform(w[None, :])
```

For a ~70-minute Neuropixels recording with ~15M spikes and ~20 sparse channels per unit, this results in ~300M individual sklearn `transform` calls, each on a 1×210 matrix. The Python/sklearn per-call overhead dominates: the actual linear algebra (a 1×210 by 210×5 matrix multiply) is trivial.
In practice, `run_for_all_spikes` takes longer than streaming the entire 88 GB recording from S3 and writing it to disk.
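A standalone micro-benchmark (with made-up sizes matching the 210-sample, 5-component shapes above, not SpikeInterface's actual arrays) makes the overhead concrete: one `transform` call per row versus a single call on the whole batch produces identical results, but the looped version pays the Python/sklearn call machinery thousands of times over.

```python
import time

import numpy as np
from sklearn.decomposition import PCA

# Hypothetical sizes: 5,000 spikes, 210 samples per waveform, 5 PCs
n_spikes, n_samples, n_components = 5_000, 210, 5
rng = np.random.default_rng(0)
waveforms = rng.standard_normal((n_spikes, n_samples)).astype("float32")

pca = PCA(n_components=n_components).fit(waveforms)

# One transform call per spike (the current inner-loop pattern)
t0 = time.perf_counter()
per_spike = np.vstack([pca.transform(w[None, :]) for w in waveforms])
t_loop = time.perf_counter() - t0

# One transform call for the whole batch
t0 = time.perf_counter()
batched = pca.transform(waveforms)
t_batch = time.perf_counter() - t0

print(f"loop: {t_loop:.3f}s  batch: {t_batch:.4f}s")

# Same numbers either way; only the call count changes
assert np.allclose(per_spike, batched, atol=1e-3)
```

The gap grows linearly with spike count, since the batched path amortizes the fixed per-call cost across the whole matrix multiply.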
Suggested optimization
In `by_channel_local` mode, the PCA model is per-channel, not per-unit: `pca_model[chan_ind]` applies to all spikes on that channel regardless of unit. This means all spikes within a chunk that use the same channel can be batched into a single `transform` call:
```python
# Batch all spikes in the chunk for a given channel
for chan_ind in unique_channels_in_chunk:
    # Gather waveform snippets for all spikes using this channel
    wfs = np.stack([
        traces[st - start - nbefore : st - start + nafter, chan_ind]
        for st in spike_times_using_this_channel
    ])  # shape: (n_spikes_on_channel, 210)
    # One transform call instead of n_spikes_on_channel calls
    all_pcs[spike_indices, :, channel_position] = pca_model[chan_ind].transform(wfs)
```

This reduces ~300M Python calls to ~384 per chunk (one per channel), with sklearn operating on batch matrices instead of individual rows.
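The grouping can be checked for correctness in isolation. The sketch below (all names and sizes are invented for illustration, not SpikeInterface's variables) fits one PCA per channel, computes projections both ways, and verifies that one `transform` call per unique channel reproduces the per-spike loop exactly:

```python
import numpy as np
from sklearn.decomposition import PCA

n_channels, n_samples, n_components = 4, 210, 5
rng = np.random.default_rng(0)

# One fitted PCA model per channel, as in by_channel_local mode
pca_model = [
    PCA(n_components=n_components).fit(rng.standard_normal((100, n_samples)))
    for _ in range(n_channels)
]

# Fake chunk: each spike has a waveform snippet and a channel index
n_spikes = 1_000
snippets = rng.standard_normal((n_spikes, n_samples))
spike_channels = rng.integers(0, n_channels, size=n_spikes)

# Reference: one transform call per spike (current behavior)
ref = np.empty((n_spikes, n_components))
for i in range(n_spikes):
    ref[i] = pca_model[spike_channels[i]].transform(snippets[i][None, :])

# Batched: one transform call per unique channel in the chunk
batched = np.empty_like(ref)
for chan_ind in np.unique(spike_channels):
    mask = spike_channels == chan_ind
    batched[mask] = pca_model[chan_ind].transform(snippets[mask])

assert np.allclose(ref, batched)
```

Because `transform` is a deterministic linear projection, batching changes only the call granularity, not the output, so the results match to floating-point precision.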
RAM impact
Minimal. For a 1-second chunk with ~3,600 spikes, batching all spikes for one channel is 3,600 × 210 × 4 bytes ≈ 3 MB. Even batching all channels simultaneously would be ~60 MB, well within reason.
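The estimates above can be checked with a couple of lines of arithmetic (float32 samples, 4 bytes each; the ~20 channels per unit comes from the sparsity figure quoted earlier):

```python
# Verify the RAM estimates for a 1-second chunk
n_spikes, n_samples, n_chans_per_unit = 3_600, 210, 20

one_channel = n_spikes * n_samples * 4          # bytes for one channel's batch
all_channels = one_channel * n_chans_per_unit   # all sparse channels at once

print(one_channel / 1e6, "MB")   # ~3 MB
print(all_channels / 1e6, "MB")  # ~60 MB
```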