Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 55 additions & 17 deletions src/spikeinterface/postprocessing/principal_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,27 +633,65 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):
end = int(spike_times[i1 - 1] + nafter)
traces = recording.get_traces(start_frame=start, end_frame=end, segment_index=segment_index)

for i in range(i0, i1):
st = spike_times[i]
if st - start - nbefore < 0:
continue
if st - start + nafter > traces.shape[0]:
continue
nsamples = nbefore + nafter

wf = traces[st - start - nbefore : st - start + nafter, :]
# Extract all waveforms in the chunk at once
# valid_mask tracks which spikes have valid (in-bounds) waveforms
chunk_spike_times = spike_times[i0:i1]
offsets = chunk_spike_times - start - nbefore
valid_mask = (offsets >= 0) & (offsets + nsamples <= traces.shape[0])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is me probably misreading but I'm super bad at parsing these type of > and < in general. If we have to be less than or = to the shape couldn't we run into an issue where we are = to the shape which is out of bounds?

i.e., for an array of shape (4, 5), shape[0] = 4, but if I try to index on 4 it will be an out-of-bounds error. Again I don't work on the PC code at all so maybe I'm completely wrong here.


unit_index = spike_labels[i]
chan_inds = unit_channels[unit_index]
if not np.any(valid_mask):
return

valid_offsets = offsets[valid_mask]
valid_indices = np.arange(i0, i1)[valid_mask]
n_valid = len(valid_offsets)

# Build waveform array: (n_valid, nsamples, n_channels)
# Use fancy indexing to extract all snippets at once
sample_indices = valid_offsets[:, None] + np.arange(nsamples)[None, :] # (n_valid, nsamples)
all_wfs = traces[sample_indices] # (n_valid, nsamples, n_channels)

# Vectorized PCA: batch by channel across all spikes in the chunk.
# For each unique channel, find all spikes that use it (via their unit's
# sparsity), extract waveforms, and call transform once.
valid_labels = spike_labels[valid_indices]

# Build a set of all channels used by spikes in this chunk
unique_unit_indices = np.unique(valid_labels)
chan_info: dict[int, list[tuple[np.ndarray, int]]] = {}
for unit_index in unique_unit_indices:
chan_inds = unit_channels[unit_index]
unit_mask = valid_labels == unit_index
unit_local_idxs = np.nonzero(unit_mask)[0]
for c, chan_ind in enumerate(chan_inds):
w = wf[:, chan_ind]
if w.size > 0:
w = w[None, :]
try:
all_pcs[i, :, c] = pca_model[chan_ind].transform(w)
except:
# this could happen if len(wfs) is less than n_comp for a channel
pass
if chan_ind not in chan_info:
chan_info[chan_ind] = []
chan_info[chan_ind].append((unit_local_idxs, c))

for chan_ind, unit_groups in chan_info.items():
# Concatenate all spike indices for this channel across units
all_local_idxs = np.concatenate([g[0] for g in unit_groups])
global_idxs = valid_indices[all_local_idxs]

# Batch waveforms for this channel: (n_spikes, nsamples)
wfs_batch = all_wfs[all_local_idxs, :, chan_ind]

if wfs_batch.size == 0:
continue

try:
pcs_batch = pca_model[chan_ind].transform(wfs_batch)
# Write results back — each unit group has a fixed channel position
offset = 0
for unit_local_idxs, c_pos in unit_groups:
n = len(unit_local_idxs)
all_pcs[global_idxs[offset : offset + n], :, c_pos] = pcs_batch[offset : offset + n]
offset += n
except Exception:
# this could happen if len(wfs) is less than n_comp for a channel
pass


def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model):
Expand Down
Loading