Skip to content

Commit

Permalink
added convenience function to detection that behaves like s SI sortin…
Browse files Browse the repository at this point in the history
…g component
  • Loading branch information
mhhennig committed Jul 25, 2024
1 parent e5e6fe2 commit f4113b4
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
4 changes: 2 additions & 2 deletions herdingspikes/detection_lightning/detect.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,8 @@ class HSDetectionLightning(object):
)

result: dict[str, RealArray] = {
"sample_ind": sample_ind,
"channel_ind": channel_ind,
"sample_index": sample_ind,
"channel_index": channel_ind,
"amplitude": amplitude,
}
if self.localize:
Expand Down
51 changes: 49 additions & 2 deletions herdingspikes/hs2.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,8 @@ def DetectFromRaw(self):
sp[0]["spike_shape"] = np.zeros(len(sp[0]["sample_ind"]))
self.spikes = pd.DataFrame(
{
"ch": sp[0]["channel_ind"],
"t": sp[0]["sample_ind"],
"ch": sp[0]["channel_index"],
"t": sp[0]["sample_index"],
"Amplitude": sp[0]["amplitude"],
"x": sp[0]["location"][:, 0],
"y": sp[0]["location"][:, 1],
Expand All @@ -540,6 +540,7 @@ def DetectFromRaw(self):
else:
h["cutout_length"] = 0
h.close()
return sp

def PlotTracesChannels(
self,
Expand Down Expand Up @@ -1432,3 +1433,49 @@ def PlotNeighbourhood(
ax[1].plot(self.spikes.Shape[i], color=(0.4, 0.4, 0.4))
ax[1].plot(np.mean(self.spikes.Shape[spInds].values, axis=0), color="k")
return ax


def detect_peaks_lightning(recording, params=None):
"""
Detect spikes in a recording using the lightning framework. This function is compatible
with the SpikeInterface sorting components framework. Note it does not return spike locations.
Parameters
----------
recording : RecordingExtractor
The recording extractor object
params : dict
The parameters for the spike detection. If None, default parameters are used.
Returns
-------
peaks : np.array
Structured array with the detected peaks. Fields are:
* 'sample_index' : int
The index of the peak sample
* 'channel_index' : int
The index of the channel
* 'amplitude' : float
The amplitude of the peak
"""

det = HSDetectionLightning(recording, params=params)
peaks = det.DetectFromRaw()
peaks_array = np.array(
list(
tuple(
map(
tuple,
np.array(
[
(peaks[0][k])
for k in ["sample_index", "channel_index", "amplitude"]
]
).T,
)
)
),
dtype=[("sample_index", "<i8"), ("channel_index", "<i8"), ("amplitude", "<f8")],
)
return peaks_array

0 comments on commit f4113b4

Please sign in to comment.