run_for_all_spikes: vectorize per-channel PCA transform #4485

@galenlynch

Problem

The inner loop in _all_pc_extractor_chunk (in principal_component.py) calls pca_model[chan_ind].transform() once per spike per channel:

for i in range(i0, i1):
    wf = traces[st - start - nbefore : st - start + nafter, :]
    ...
    for c, chan_ind in enumerate(chan_inds):
        w = wf[:, chan_ind]
        ...
        all_pcs[i, :, c] = pca_model[chan_ind].transform(w[None, :])

For a ~70-minute Neuropixels recording with ~15M spikes and ~20 sparse channels per unit, this results in ~300M individual sklearn transform calls, each on a 1×210 matrix. The Python/sklearn per-call overhead dominates — the actual linear algebra (a 1×210 by 210×5 matrix multiply) is trivial.
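The call count follows directly from the figures quoted above. As a rough check (all inputs are the approximate numbers from this issue, not measurements):

```python
# Back-of-the-envelope count of transform calls made by the per-spike loop.
# Inputs are the approximate figures quoted in this issue.
n_spikes = 15_000_000      # ~15M spikes in the recording
channels_per_spike = 20    # ~20 sparse channels per unit
n_samples = 210            # samples per waveform snippet
n_components = 5           # PCA components per channel

# One transform call per spike per channel.
n_calls = n_spikes * channels_per_spike

# Multiply-adds in a single (1 x 210) @ (210 x 5) product.
flops_per_call = 2 * n_samples * n_components

print(f"transform calls: {n_calls:,}")
print(f"flops per call:  {flops_per_call:,}")
```

A couple of thousand floating-point operations per call is far less work than the Python-level dispatch and input validation around it, which is why the per-call overhead dominates.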

In practice, run_for_all_spikes takes longer than streaming the entire 88 GB recording from S3 and writing it to disk.

Suggested optimization

In by_channel_local mode, the PCA model is per-channel, not per-unit — pca_model[chan_ind] applies to all spikes on that channel regardless of unit. This means all spikes within a chunk that use the same channel can be batched into a single transform call:

# Batch all spikes in the chunk for a given channel
for chan_ind in unique_channels_in_chunk:
    # Gather waveform snippets for all spikes using this channel
    wfs = np.stack([
        traces[st - start - nbefore : st - start + nafter, chan_ind]
        for st in spike_times_using_this_channel
    ])  # shape: (n_spikes_on_channel, 210)

    # One transform call instead of n_spikes_on_channel calls
    all_pcs[spike_indices, :, channel_position] = pca_model[chan_ind].transform(wfs)

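This reduces the call count from ~300M in total to at most ~384 per chunk (one per channel), with sklearn operating on batch matrices instead of individual rows. Assuming 1-second chunks, the whole 70-minute recording would then need at most about 1.6M calls (4,200 chunks × 384 channels).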
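The equivalence of the two approaches can be checked on synthetic data. The sketch below is not the SpikeInterface implementation: `LinearPCA`, `project_per_spike`, and `project_batched` are hypothetical names, the waveform snippets are assumed to be already cut into an array of shape `(n_spikes, n_samples, n_sparse)`, and `LinearPCA` only mimics the centering and projection of a non-whitened PCA so that the example needs NumPy alone.

```python
import numpy as np


class LinearPCA:
    """Hypothetical stand-in for a fitted per-channel PCA model (no whitening)."""

    def __init__(self, mean, components):
        self.mean = mean              # shape (n_samples,)
        self.components = components  # shape (n_components, n_samples)

    def transform(self, x):
        # x: (n_rows, n_samples) -> (n_rows, n_components)
        return (x - self.mean) @ self.components.T


def project_per_spike(wfs, spike_chan_inds, pca_models):
    """Baseline: one transform call per spike per channel."""
    n_spikes, _, n_sparse = wfs.shape
    n_comp = pca_models[0].components.shape[0]
    out = np.zeros((n_spikes, n_comp, n_sparse))
    for i in range(n_spikes):
        for c in range(n_sparse):
            chan_ind = spike_chan_inds[i, c]
            out[i, :, c] = pca_models[chan_ind].transform(wfs[i, :, c][None, :])
    return out


def project_batched(wfs, spike_chan_inds, pca_models):
    """Batched: one transform call per distinct channel in the chunk."""
    n_spikes, _, n_sparse = wfs.shape
    n_comp = pca_models[0].components.shape[0]
    out = np.zeros((n_spikes, n_comp, n_sparse))
    for chan_ind in np.unique(spike_chan_inds):
        # Every (spike, sparse-slot) pair in the chunk that uses this channel.
        rows, cols = np.nonzero(spike_chan_inds == chan_ind)
        batch = wfs[rows, :, cols]  # shape (n_hits, n_samples)
        out[rows, :, cols] = pca_models[chan_ind].transform(batch)
    return out


# Small synthetic chunk: 50 spikes, 3 sparse channels each, 8 channels total.
rng = np.random.default_rng(0)
n_channels, n_samples, n_comp = 8, 210, 5
n_spikes, n_sparse = 50, 3

pca_models = [
    LinearPCA(rng.normal(size=n_samples), rng.normal(size=(n_comp, n_samples)))
    for _ in range(n_channels)
]
# Each spike uses n_sparse distinct channels.
spike_chan_inds = np.stack(
    [rng.choice(n_channels, size=n_sparse, replace=False) for _ in range(n_spikes)]
)
wfs = rng.normal(size=(n_spikes, n_samples, n_sparse))

pcs_loop = project_per_spike(wfs, spike_chan_inds, pca_models)
pcs_batch = project_batched(wfs, spike_chan_inds, pca_models)
print("results match:", np.allclose(pcs_loop, pcs_batch))
```

Grouping by `(spike, sparse-slot)` pairs rather than by spike alone lets the output column differ from spike to spike, which matters because the same channel can sit at different positions in different units' sparsity masks.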

RAM impact

Minimal. For a 1-second chunk with ~3,600 spikes, batching all spikes for one channel takes 3,600 × 210 × 4 bytes ≈ 3 MB (at 4 bytes per sample). Even batching all ~20 sparse channels at once would need only ~60 MB, which is well within reason.
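The memory estimate can be reproduced as follows. The 4 bytes per sample assume float32 waveforms, and the factor of 20 reuses the ~20 sparse channels quoted in the Problem section; both are assumptions rather than measured values.

```python
# Rough memory footprint of the batched waveform buffers for one chunk.
spikes_per_chunk = 3_600   # ~3,600 spikes in a 1-second chunk
n_samples = 210            # samples per waveform snippet
bytes_per_sample = 4       # assumes float32
channels_per_spike = 20    # assumes ~20 sparse channels per unit

one_channel_bytes = spikes_per_chunk * n_samples * bytes_per_sample
all_channels_bytes = one_channel_bytes * channels_per_spike

print(f"one channel:  {one_channel_bytes / 1e6:.1f} MB")
print(f"all channels: {all_channels_bytes / 1e6:.1f} MB")
```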

Labels: performance (Performance issues/improvements), postprocessing (Related to postprocessing module)
