From 64034b76ac02c453446deb88817c0f599f5a7e5a Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Wed, 1 Apr 2026 09:05:41 -0700 Subject: [PATCH 1/2] feat: vectorize per-channel PCA transform in `run_for_all_spikes` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `_all_pc_extractor_chunk` calls `pca_model[chan_ind].transform()` once per spike per channel in a Python loop. For a ~70-minute Neuropixels recording with ~10M spikes and ~26 sparse channels, this is ~260M individual sklearn `transform` calls, each on a 1×210 matrix. The Python/sklearn per-call overhead dominates. This commit improves performance by batching all spikes within a chunk by channel: 1. Extract all valid waveform snippets in the chunk at once using vectorized fancy indexing 2. Group spikes by channel index across all units 3. Call `pca_model[chan_ind].transform(wfs_batch)` once per channel with the full batch For synthetic data with 500 spikes, 10 channels, and 50-sample waveforms, this improves performance 53x, from 0.126s to 0.002s, with a max absolute difference in projections of 9.5e-7. For an integration benchmark extracting PCs from a 5-minute, 379-channel .dat file held in RAM with 706k spikes and 26 sparse channels × 210-sample waveforms, the vectorization improves performance 5x, from 9.1 minutes to 1.8 minutes. The max absolute difference between the two paths was 1.49e-08. 
--- .../postprocessing/principal_component.py | 67 ++++++++++++++----- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index bb48a08e64..a84ea6dc28 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -633,27 +633,60 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): end = int(spike_times[i1 - 1] + nafter) traces = recording.get_traces(start_frame=start, end_frame=end, segment_index=segment_index) - for i in range(i0, i1): - st = spike_times[i] - if st - start - nbefore < 0: - continue - if st - start + nafter > traces.shape[0]: - continue + nsamples = nbefore + nafter - wf = traces[st - start - nbefore : st - start + nafter, :] + # Extract all waveforms in the chunk at once + # valid_mask tracks which spikes have valid (in-bounds) waveforms + chunk_spike_times = spike_times[i0:i1] + offsets = chunk_spike_times - start - nbefore + valid_mask = (offsets >= 0) & (offsets + nsamples <= traces.shape[0]) - unit_index = spike_labels[i] - chan_inds = unit_channels[unit_index] + if not np.any(valid_mask): + return + + valid_offsets = offsets[valid_mask] + valid_indices = np.arange(i0, i1)[valid_mask] + n_valid = len(valid_offsets) + + # Build waveform array: (n_valid, nsamples, n_channels) + # Use fancy indexing to extract all snippets at once + sample_indices = valid_offsets[:, None] + np.arange(nsamples)[None, :] # (n_valid, nsamples) + all_wfs = traces[sample_indices] # (n_valid, nsamples, n_channels) + + # Vectorized PCA: batch by channel across all spikes in the chunk + # Build a mapping: for each channel, which spikes use it and at what position + valid_labels = spike_labels[valid_indices] + # Collect (spike_local_idx, channel_position, channel_index) for all spike-channel pairs + chan_to_spikes: dict[int, list[tuple[int, 
int]]] = {} + for local_idx in range(n_valid): + unit_index = valid_labels[local_idx] + chan_inds = unit_channels[unit_index] for c, chan_ind in enumerate(chan_inds): - w = wf[:, chan_ind] - if w.size > 0: - w = w[None, :] - try: - all_pcs[i, :, c] = pca_model[chan_ind].transform(w) - except: - # this could happen if len(wfs) is less then n_comp for a channel - pass + if chan_ind not in chan_to_spikes: + chan_to_spikes[chan_ind] = [] + chan_to_spikes[chan_ind].append((local_idx, c)) + + for chan_ind, spike_chan_pairs in chan_to_spikes.items(): + local_idxs = np.array([p[0] for p in spike_chan_pairs]) + chan_positions = np.array([p[1] for p in spike_chan_pairs]) + global_idxs = valid_indices[local_idxs] + + # Batch waveforms for this channel: (n_spikes, nsamples) + wfs_batch = all_wfs[local_idxs, :, chan_ind] + + if wfs_batch.size == 0: + continue + + try: + pcs_batch = pca_model[chan_ind].transform(wfs_batch) # (n_spikes, n_components) + # Write results — group by channel position to use vectorized indexing + for c_pos in np.unique(chan_positions): + mask = chan_positions == c_pos + all_pcs[global_idxs[mask], :, c_pos] = pcs_batch[mask] + except Exception: + # this could happen if len(wfs) is less than n_comp for a channel + pass def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model): From 34484520f381ee45004e3f754bf837c6b1d13d28 Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Wed, 1 Apr 2026 14:30:49 -0700 Subject: [PATCH 2/2] perf: replace per-spike dict loop with numpy unit grouping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of iterating over every spike to build the channel-to-spike mapping, loop over unique unit indices and use vectorized boolean masks. Reduces Python loop iterations from n_spikes*n_sparse_channels (~600K) to n_units*n_sparse_channels (~10K) per chunk. 5.0x → 5.6x speedup vs original on real Neuropixels data. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../postprocessing/principal_component.py | 43 +++++++++++-------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index a84ea6dc28..785b47c787 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -653,37 +653,42 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): sample_indices = valid_offsets[:, None] + np.arange(nsamples)[None, :] # (n_valid, nsamples) all_wfs = traces[sample_indices] # (n_valid, nsamples, n_channels) - # Vectorized PCA: batch by channel across all spikes in the chunk - # Build a mapping: for each channel, which spikes use it and at what position + # Vectorized PCA: batch by channel across all spikes in the chunk. + # For each unique channel, find all spikes that use it (via their unit's + # sparsity), extract waveforms, and call transform once. 
valid_labels = spike_labels[valid_indices] - # Collect (spike_local_idx, channel_position, channel_index) for all spike-channel pairs - chan_to_spikes: dict[int, list[tuple[int, int]]] = {} - for local_idx in range(n_valid): - unit_index = valid_labels[local_idx] + # Build a set of all channels used by spikes in this chunk + unique_unit_indices = np.unique(valid_labels) + chan_info: dict[int, list[tuple[np.ndarray, int]]] = {} + for unit_index in unique_unit_indices: chan_inds = unit_channels[unit_index] + unit_mask = valid_labels == unit_index + unit_local_idxs = np.nonzero(unit_mask)[0] for c, chan_ind in enumerate(chan_inds): - if chan_ind not in chan_to_spikes: - chan_to_spikes[chan_ind] = [] - chan_to_spikes[chan_ind].append((local_idx, c)) + if chan_ind not in chan_info: + chan_info[chan_ind] = [] + chan_info[chan_ind].append((unit_local_idxs, c)) - for chan_ind, spike_chan_pairs in chan_to_spikes.items(): - local_idxs = np.array([p[0] for p in spike_chan_pairs]) - chan_positions = np.array([p[1] for p in spike_chan_pairs]) - global_idxs = valid_indices[local_idxs] + for chan_ind, unit_groups in chan_info.items(): + # Concatenate all spike indices for this channel across units + all_local_idxs = np.concatenate([g[0] for g in unit_groups]) + global_idxs = valid_indices[all_local_idxs] # Batch waveforms for this channel: (n_spikes, nsamples) - wfs_batch = all_wfs[local_idxs, :, chan_ind] + wfs_batch = all_wfs[all_local_idxs, :, chan_ind] if wfs_batch.size == 0: continue try: - pcs_batch = pca_model[chan_ind].transform(wfs_batch) # (n_spikes, n_components) - # Write results — group by channel position to use vectorized indexing - for c_pos in np.unique(chan_positions): - mask = chan_positions == c_pos - all_pcs[global_idxs[mask], :, c_pos] = pcs_batch[mask] + pcs_batch = pca_model[chan_ind].transform(wfs_batch) + # Write results back — each unit group has a fixed channel position + offset = 0 + for unit_local_idxs, c_pos in unit_groups: + n = 
len(unit_local_idxs) + all_pcs[global_idxs[offset : offset + n], :, c_pos] = pcs_batch[offset : offset + n] + offset += n except Exception: # this could happen if len(wfs) is less than n_comp for a channel pass