
Vectorize per-channel PCA transform in run_for_all_spikes #4488

Open
galenlynch wants to merge 2 commits into SpikeInterface:main from galenlynch:feat/vectorize-per-channel-pca

Conversation

galenlynch commented Apr 1, 2026

Problem

_all_pc_extractor_chunk calls pca_model[chan_ind].transform() once per spike per channel in a Python loop:

for i in range(i0, i1):
    st = spike_times[i]
    wf = traces[st - start - nbefore : st - start + nafter, :]
    for c, chan_ind in enumerate(chan_inds):
        w = wf[:, chan_ind]
        all_pcs[i, :, c] = pca_model[chan_ind].transform(w[None, :])

For a ~70-minute Neuropixels recording with ~10M spikes and ~26 sparse channels, this is ~260M individual sklearn transform calls, each on a 1×210 matrix. The Python/sklearn per-call overhead dominates.

Solution

Batch all spikes within a chunk by channel and call transform once per channel:

  1. Extract all valid waveform snippets in the chunk at once using vectorized fancy indexing
  2. Group spikes by channel index across all units
  3. Call pca_model[chan_ind].transform(wfs_batch) once per channel with the full batch

In by_channel_local mode the PCA model is per-channel (not per-unit), so all spikes on a given channel share the same model regardless of unit identity.
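The batched call pattern can be sketched as follows. This is a minimal standalone sketch with synthetic shapes mirroring the benchmark below; all names here are illustrative stand-ins, not the actual principal_component.py internals:

```python
import numpy as np
from sklearn.decomposition import PCA

# Synthetic stand-ins: 500 spikes, 50-sample waveforms, 10 channels,
# one PCA model per channel (as in by_channel_local mode)
rng = np.random.default_rng(0)
n_spikes, n_samples, n_channels, n_components = 500, 50, 10, 3
waveforms = rng.standard_normal((n_spikes, n_samples, n_channels))
pca_model = [PCA(n_components).fit(waveforms[:, :, c]) for c in range(n_channels)]

# Original pattern: one transform() call per spike per channel
all_pcs_loop = np.zeros((n_spikes, n_components, n_channels))
for i in range(n_spikes):
    for c in range(n_channels):
        all_pcs_loop[i, :, c] = pca_model[c].transform(waveforms[i, :, c][None, :])

# Batched pattern: one transform() call per channel on the full spike batch
all_pcs_batch = np.zeros_like(all_pcs_loop)
for c in range(n_channels):
    all_pcs_batch[:, :, c] = pca_model[c].transform(waveforms[:, :, c])

# Both paths compute (x - mean) @ components.T, so results agree to rounding
assert np.allclose(all_pcs_loop, all_pcs_batch)
```

The per-call Python and sklearn validation overhead is paid n_channels times instead of n_spikes × n_channels times, which is where the speedup comes from.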

Benchmarks

Synthetic (500 spikes, 10 channels, 50-sample waveforms)

            Time     Speedup
Original    0.126 s  -
Vectorized  0.002 s  53x

Results match: max absolute difference 9.5e-7 (float rounding).

Real data (Neuropixels probe 50213-3, 5 minutes of recording from local .dat)

  • 706K total spikes
  • 379 channels, 26 sparse channels per unit, 210-sample waveforms
  • n_jobs=1, chunk_duration=10s
                             Time               Per chunk     Speedup
Original                     548.7 s (9.1 min)  18.3 s/chunk  -
Vectorized                   110.3 s (1.8 min)  2.1 s/chunk   5.0x
Vectorized + numpy grouping  108.2 s            2.0 s/chunk   5.6x

Results match: max absolute difference 1.49e-08; np.allclose confirms the outputs agree within tolerance.

The real-data speedup is lower than synthetic because disk I/O, waveform extraction, and memory allocation are shared costs. The optimization only affects the transform call overhead, which is ~80% of chunk time in the original code.

Projected impact

For a full 69-minute recording, PC extraction drops from ~60 min to ~12 min.

RAM impact

Minimal. For a 10s chunk with ~25K spikes, the batch waveform array is ~25K × 210 × 4 bytes ≈ 20 MB per channel call, reused across channels. Peak additional memory vs. original: ~50 MB.
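The estimate above can be checked with quick arithmetic (chunk and waveform sizes taken from the benchmark figures in this description):

```python
# Back-of-envelope size of the batch waveform array for one channel call,
# assuming float32 samples as in the benchmarked recording
n_spikes_chunk = 25_000   # spikes in a 10 s chunk
n_samples = 210           # waveform length
bytes_per_sample = 4      # float32

batch_bytes = n_spikes_chunk * n_samples * bytes_per_sample
print(f"{batch_bytes / 1e6:.0f} MB per channel call")  # ~21 MB, reused across channels
```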

Changes

  • _all_pc_extractor_chunk in principal_component.py: replaced per-spike per-channel loop with vectorized batch-by-channel approach
  • No API changes, no new dependencies
  • All existing test_principal_component.py tests pass
  • A follow-up optimization replaces the per-spike Python dict loop (building the channel→spike mapping) with numpy unit grouping: loop over unique unit indices and use vectorized boolean masks. This reduces Python iterations from n_spikes × n_sparse_channels (~600K) to n_units × n_sparse_channels (~10K) per chunk.
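The numpy grouping in that follow-up can be sketched like this. The array names and shapes are hypothetical stand-ins for the per-chunk data, not the actual variables in principal_component.py:

```python
import numpy as np

# Hypothetical per-chunk inputs: each spike has a unit label, and each
# unit has a fixed set of sparse channel indices
rng = np.random.default_rng(0)
n_spikes, n_units, n_channels, n_sparse = 1_000, 5, 40, 4
spike_unit = rng.integers(0, n_units, size=n_spikes)
sparse_channels = {u: rng.choice(n_channels, size=n_sparse, replace=False)
                   for u in range(n_units)}

# Build the channel -> spike-row mapping with one vectorized mask per unit
# (n_units * n_sparse iterations) instead of appending inside a per-spike
# Python loop (n_spikes * n_sparse iterations)
spikes_by_channel = {}
for u in range(n_units):
    unit_rows = np.flatnonzero(spike_unit == u)  # all spikes of this unit at once
    for c, chan_ind in enumerate(sparse_channels[u]):
        spikes_by_channel.setdefault(int(chan_ind), []).append((unit_rows, c))
```

Each channel's entry then holds (spike rows, output column) pairs that can be concatenated and fed to a single transform call per channel.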

Fixes #4485
Related: #979

galenlynch and others added 2 commits April 1, 2026 10:43
`_all_pc_extractor_chunk` calls `pca_model[chan_ind].transform()` once per spike
per channel in a Python loop.

For a ~70-minute Neuropixels recording with ~10M spikes and ~26 sparse channels,
this is ~260M individual sklearn `transform` calls, each on a 1×210 matrix. The
Python/sklearn per-call overhead dominates.

This commit improves performance by batching all spikes within a chunk by
channel:

1. Extract all valid waveform snippets in the chunk at once using vectorized
   fancy indexing
2. Group spikes by channel index across all units
3. Call `pca_model[chan_ind].transform(wfs_batch)` once per channel with the
   full batch

For synthetic data of 500 spikes, 10 channels, with 50-sample waveforms, this
improves performance 53x, from 0.126s to 0.002s, with max absolute difference in
projections of 9.5e-7.

For an integration benchmark extracting PCs from a 5-minute, 379 channel .dat in
RAM with 706k spikes and 26 sparse channels x 210-sample waveforms, the
vectorization improves performance 5x, from 9.1 minutes to 1.8 minutes. Max
absolute difference between the two paths was 1.49e-08.
Instead of iterating over every spike to build the channel-to-spike
mapping, loop over unique unit indices and use vectorized boolean
masks. Reduces Python loop iterations from n_spikes*n_sparse_channels
(~600K) to n_units*n_sparse_channels (~10K) per chunk.

5.0x → 5.6x speedup vs original on real Neuropixels data.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
galenlynch (Author) commented:

The test failure is unrelated. Seems like a problem with cross-correlogram tests and fast_mode?

alejoe91 added the postprocessing (Related to postprocessing module) and performance (Performance issues/improvements) labels on Apr 2, 2026
