Vectorize per-channel PCA transform in `run_for_all_spikes` #4488
Open
galenlynch wants to merge 2 commits into SpikeInterface:main from
Conversation
`_all_pc_extractor_chunk` calls `pca_model[chan_ind].transform()` once per spike per channel in a Python loop. For a ~70-minute Neuropixels recording with ~10M spikes and ~26 sparse channels, this is ~260M individual sklearn `transform` calls, each on a 1×210 matrix. The Python/sklearn per-call overhead dominates.

This commit improves performance by batching all spikes within a chunk by channel:

1. Extract all valid waveform snippets in the chunk at once using vectorized fancy indexing
2. Group spikes by channel index across all units
3. Call `pca_model[chan_ind].transform(wfs_batch)` once per channel with the full batch

For synthetic data of 500 spikes, 10 channels, with 50-sample waveforms, this improves performance 53x, from 0.126 s to 0.002 s, with a max absolute difference in projections of 9.5e-7. For an integration benchmark extracting PCs from a 5-minute, 379-channel .dat in RAM with 706k spikes and 26 sparse channels × 210-sample waveforms, the vectorization improves performance 5x, from 9.1 minutes to 1.8 minutes. Max absolute difference between the two paths was 1.49e-08.
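Step 1 above (all snippets at once via fancy indexing) can be sketched with a broadcasted index matrix. The array shapes and the `nbefore`/`nafter` values below are hypothetical stand-ins, not SpikeInterface's actual internals:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_chans = 100_000, 4
traces = rng.normal(size=(n_samples, n_chans)).astype("float32")
peaks = rng.integers(50, n_samples - 50, size=300)  # valid spike sample indices

# Broadcasting peaks (n_spikes, 1) against offsets (1, snippet) yields a
# (n_spikes, snippet) index matrix, so a single fancy-indexing call extracts
# every snippet instead of slicing `traces` once per spike in a Python loop.
nbefore, nafter = 20, 30
offsets = np.arange(-nbefore, nafter)
wfs = traces[peaks[:, None] + offsets[None, :], :]
print(wfs.shape)  # (300, 50, 4): n_spikes x snippet x n_chans
```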
Instead of iterating over every spike to build the channel-to-spike mapping, loop over unique unit indices and use vectorized boolean masks. This reduces Python loop iterations from `n_spikes * n_sparse_channels` (~600K) to `n_units * n_sparse_channels` (~10K) per chunk. 5.0x → 5.6x speedup vs. the original on real Neuropixels data.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
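The unit-mask idea can be sketched as follows; `spike_unit` and `unit_channels` here are made-up stand-ins for the chunk's spike/unit bookkeeping, not the PR's actual variable names:

```python
import numpy as np

rng = np.random.default_rng(1)
n_spikes, n_units = 1_000, 5
spike_unit = rng.integers(0, n_units, size=n_spikes)  # unit index per spike
unit_channels = {u: [u, u + 1, u + 2] for u in range(n_units)}  # hypothetical sparsity

# Loop over units (few) instead of spikes (many): one boolean mask per unit
# selects all of that unit's spikes at once, and those indices are appended
# to each of the unit's sparse channels.
chan_spikes: dict[int, list] = {}
for u in range(n_units):
    in_unit = np.flatnonzero(spike_unit == u)
    for chan_ind in unit_channels[u]:
        chan_spikes.setdefault(chan_ind, []).append(in_unit)

# One index array per channel, ready for a single transform() per channel
chan_spikes = {c: np.concatenate(v) for c, v in chan_spikes.items()}
```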
Author
The test failure is unrelated. Seems like a problem with cross-correlogram tests.
Problem
`_all_pc_extractor_chunk` calls `pca_model[chan_ind].transform()` once per spike per channel in a Python loop. For a ~70-minute Neuropixels recording with ~10M spikes and ~26 sparse channels, this is ~260M individual sklearn `transform` calls, each on a 1×210 matrix. The Python/sklearn per-call overhead dominates.

Solution
Batch all spikes within a chunk by channel and call `transform` once per channel: `pca_model[chan_ind].transform(wfs_batch)` is invoked once per channel with the full batch of waveforms.

In `by_channel_local` mode the PCA model is per-channel (not per-unit), so all spikes on a given channel share the same model regardless of unit identity.

Benchmarks
Synthetic (500 spikes, 10 channels, 50-sample waveforms)
Results match: max absolute difference 9.5e-7 (float rounding).
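The rounding-level agreement is expected: with `whiten=False`, sklearn's `PCA.transform` is just centering followed by a matrix product, so batching changes only the BLAS call pattern, not the math. A numpy-only stand-in (used here to keep the sketch dependency-free) shows the equivalence:

```python
import numpy as np

rng = np.random.default_rng(0)
wfs = rng.normal(size=(500, 210))  # 500 spikes x 210-sample waveforms

# Stand-in for a fitted PCA: mean plus top components from an SVD
mean = wfs.mean(axis=0)
_, _, vt = np.linalg.svd(wfs - mean, full_matrices=False)
components = vt[:5]

def transform(x):
    # mirrors PCA.transform with whiten=False: center, then project
    return (x - mean) @ components.T

# Original path: one call per spike on a 1 x 210 matrix
projs_loop = np.vstack([transform(w[None, :]) for w in wfs])
# Batched path: one call on the whole 500 x 210 matrix
projs_batch = transform(wfs)
assert np.allclose(projs_loop, projs_batch)
```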
Real data (Neuropixels probe 50213-3, 5 minutes of recording from local .dat)
`n_jobs=1`, `chunk_duration=10s`. Results match: max absolute difference 1.49e-08.
`np.allclose` confirms identical output.

The real-data speedup is lower than synthetic because disk I/O, waveform extraction, and memory allocation are shared costs. The optimization only affects the `transform` call overhead, which is ~80% of chunk time in the original code.

Projected impact
For a full 69-minute recording, PC extraction drops from ~60 min to ~12 min.
RAM impact
Minimal. For a 10s chunk with ~25K spikes, the batch waveform array is ~25K × 210 × 4 bytes ≈ 20 MB per channel call, reused across channels. Peak additional memory vs. original: ~50 MB.
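The ~20 MB figure is simple arithmetic (float32 waveforms assumed):

```python
# Back-of-envelope check of the per-channel batch size quoted above
n_spikes, n_samples, itemsize = 25_000, 210, 4  # float32 -> 4 bytes/sample
batch_bytes = n_spikes * n_samples * itemsize
print(batch_bytes / 1e6)  # 21.0 MB per per-channel batch, i.e. ~20 MB
```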
Changes
- `_all_pc_extractor_chunk` in `principal_component.py`: replaced per-spike per-channel loop with vectorized batch-by-channel approach
- `test_principal_component.py` tests pass

Fixes #4485
Related: #979