diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index bb48a08e64..785b47c787 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -633,27 +633,65 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): end = int(spike_times[i1 - 1] + nafter) traces = recording.get_traces(start_frame=start, end_frame=end, segment_index=segment_index) - for i in range(i0, i1): - st = spike_times[i] - if st - start - nbefore < 0: - continue - if st - start + nafter > traces.shape[0]: - continue + nsamples = nbefore + nafter - wf = traces[st - start - nbefore : st - start + nafter, :] + # Extract all waveforms in the chunk at once + # valid_mask tracks which spikes have valid (in-bounds) waveforms + chunk_spike_times = spike_times[i0:i1] + offsets = chunk_spike_times - start - nbefore + valid_mask = (offsets >= 0) & (offsets + nsamples <= traces.shape[0]) - unit_index = spike_labels[i] - chan_inds = unit_channels[unit_index] + if not np.any(valid_mask): + return + + valid_offsets = offsets[valid_mask] + valid_indices = np.arange(i0, i1)[valid_mask] + n_valid = len(valid_offsets) + + # Build waveform array: (n_valid, nsamples, n_channels) + # Use fancy indexing to extract all snippets at once + sample_indices = valid_offsets[:, None] + np.arange(nsamples)[None, :] # (n_valid, nsamples) + all_wfs = traces[sample_indices] # (n_valid, nsamples, n_channels) + + # Vectorized PCA: batch by channel across all spikes in the chunk. + # For each unique channel, find all spikes that use it (via their unit's + # sparsity), extract waveforms, and call transform once. + valid_labels = spike_labels[valid_indices] + # Build a set of all channels used by spikes in this chunk + unique_unit_indices = np.unique(valid_labels) + chan_info: dict[int, list[tuple[np.ndarray, int]]] = {} + for unit_index in unique_unit_indices: + chan_inds = unit_channels[unit_index] + unit_mask = valid_labels == unit_index + unit_local_idxs = np.nonzero(unit_mask)[0] for c, chan_ind in enumerate(chan_inds): - w = wf[:, chan_ind] - if w.size > 0: - w = w[None, :] - try: - all_pcs[i, :, c] = pca_model[chan_ind].transform(w) - except: - # this could happen if len(wfs) is less then n_comp for a channel - pass + if chan_ind not in chan_info: + chan_info[chan_ind] = [] + chan_info[chan_ind].append((unit_local_idxs, c)) + + for chan_ind, unit_groups in chan_info.items(): + # Concatenate all spike indices for this channel across units + all_local_idxs = np.concatenate([g[0] for g in unit_groups]) + global_idxs = valid_indices[all_local_idxs] + + # Batch waveforms for this channel: (n_spikes, nsamples) + wfs_batch = all_wfs[all_local_idxs, :, chan_ind] + + if wfs_batch.size == 0: + continue + + try: + pcs_batch = pca_model[chan_ind].transform(wfs_batch) + # Write results back — each unit group has a fixed channel position + offset = 0 + for unit_local_idxs, c_pos in unit_groups: + n = len(unit_local_idxs) + all_pcs[global_idxs[offset : offset + n], :, c_pos] = pcs_batch[offset : offset + n] + offset += n + except Exception: + # this could happen if len(wfs) is less than n_comp for a channel + pass def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model):