Skip to content

Commit

Permalink
Make LIF cells probeable. (#2021)
Browse files Browse the repository at this point in the history
* Functionality
  * Add probes to LIF cells.
* Docs
  * Remove errorneous statement(s) about LIF cells (there never was an E_reset...)
  * Move probing chapter one level up (concepts/cable_cells -> concepts)
* Tests
  * Add tests for LIF probes
  • Loading branch information
thorstenhater authored Nov 1, 2022
1 parent 8109635 commit 3106ff7
Show file tree
Hide file tree
Showing 20 changed files with 1,803 additions and 97 deletions.
8 changes: 7 additions & 1 deletion arbor/include/arbor/lif_cell.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@ struct ARB_SYMBOL_VISIBLE lif_cell {
double C_m = 20; // Membrane capacitance [pF].
double E_L = 0; // Resting potential [mV].
double V_m = E_L; // Initial value of the Membrane potential [mV].
double V_reset = E_L; // Reset potential [mV].
double t_ref = 2; // Refractory period [ms].

lif_cell() = delete;
lif_cell(cell_tag_type source, cell_tag_type target): source(std::move(source)), target(std::move(target)) {}
};

// LIF probe metadata, to be passed to sampler callbacks. Intentionally left blank.
struct ARB_SYMBOL_VISIBLE lif_probe_metadata {};

// Voltage estimate [mV].
// Sample value type: `double`
struct ARB_SYMBOL_VISIBLE lif_probe_voltage {};

} // namespace arb
218 changes: 169 additions & 49 deletions arbor/lif_cell_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "profile/profiler_macro.hpp"
#include "util/rangeutil.hpp"
#include "util/span.hpp"
#include "util/filter.hpp"
#include "util/maputil.hpp"

using namespace arb;

Expand All @@ -13,15 +15,24 @@ lif_cell_group::lif_cell_group(const std::vector<cell_gid_type>& gids, const rec
gids_(gids)
{
for (auto gid: gids_) {
if (!rec.get_probes(gid).empty()) {
throw bad_cell_probe(cell_kind::lif, gid);
auto probes = rec.get_probes(gid);
for (const auto lid: util::count_along(probes)) {
const auto& probe = probes[lid];
if (probe.address.type() == typeid(lif_probe_voltage)) {
cell_member_type id{gid, static_cast<cell_lid_type>(lid)};
probes_[id] = {probe.tag, lif_probe_kind::voltage, {}};
}
else {
throw bad_cell_probe{cell_kind::lif, gid};
}
}
}
// Default to no binning of events
lif_cell_group::set_binning_policy(binning_kind::none, 0);

cells_.reserve(gids_.size());
last_time_updated_.resize(gids_.size());
next_time_updatable_.resize(gids_.size());

for (auto lid: util::make_span(gids_.size())) {
cells_.push_back(util::any_cast<lif_cell>(rec.get_cell_description(gids_[lid])));
Expand All @@ -41,11 +52,9 @@ cell_kind lif_cell_group::get_cell_kind() const {

void lif_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) {
PE(advance:lif);
if (event_lanes.size() > 0) {
for (auto lid: util::make_span(gids_.size())) {
// Advance each cell independently.
advance_cell(ep.t1, dt, lid, event_lanes[lid]);
}
for (auto lid: util::make_span(gids_.size())) {
// Advance each cell independently.
advance_cell(ep.t1, dt, lid, event_lanes);
}
PL();
}
Expand All @@ -59,10 +68,30 @@ void lif_cell_group::clear_spikes() {
}

// TODO: implement sampler
void lif_cell_group::add_sampler(sampler_association_handle h, cell_member_predicate probeset_ids,
schedule sched, sampler_function fn, sampling_policy policy) {}
void lif_cell_group::remove_sampler(sampler_association_handle h) {}
void lif_cell_group::remove_all_samplers() {}
void lif_cell_group::add_sampler(sampler_association_handle h,
cell_member_predicate probeset_ids,
schedule sched,
sampler_function fn,
sampling_policy policy) {
std::lock_guard<std::mutex> guard(sampler_mex_);
std::vector<cell_member_type> probeset =
util::assign_from(util::filter(util::keys(probes_), probeset_ids));
auto assoc = arb::sampler_association{std::move(sched),
std::move(fn),
std::move(probeset),
policy};
auto result = samplers_.insert({h, std::move(assoc)});
arb_assert(result.second);
}

void lif_cell_group::remove_sampler(sampler_association_handle h) {
std::lock_guard<std::mutex> guard(sampler_mex_);
samplers_.erase(h);
}
void lif_cell_group::remove_all_samplers() {
std::lock_guard<std::mutex> guard(sampler_mex_);
samplers_.clear();
}

// TODO: implement binner_
void lif_cell_group::set_binning_policy(binning_kind policy, time_type bin_interval) {
Expand All @@ -71,52 +100,143 @@ void lif_cell_group::set_binning_policy(binning_kind policy, time_type bin_inter
void lif_cell_group::reset() {
spikes_.clear();
util::fill(last_time_updated_, 0.);
util::fill(next_time_updatable_, 0.);
}

// Advances a single cell (lid) with the exact solution (jumps can be arbitrary).
// Parameter dt is ignored, since we make jumps between two consecutive spikes.
void lif_cell_group::advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, pse_vector& event_lane) {
// Current time of last update.
auto t = last_time_updated_[lid];
void lif_cell_group::advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, const event_lane_subrange& event_lanes) {
const auto gid = gids_[lid];
auto& cell = cells_[lid];
const auto n_events = event_lane.size();

// Integrate until tfinal using the exact solution of membrane voltage differential equation.
for (unsigned i=0; i<n_events; ++i ) {
auto& ev = event_lane[i];
const auto time = ev.time;
auto weight = ev.weight;

if (time < t) continue; // skip event if a neuron is in refactory period
if (time >= tfinal) break; // end of integration interval

// if there are events that happened at the same time as this event, process them as well
while (i + 1 < n_events && event_lane[i+1].time <= time) {
weight += event_lane[i+1].weight;
i++;
// time of last update.
auto t = last_time_updated_[lid];
// spikes to process
const auto n_events = static_cast<int>(event_lanes.size() ? event_lanes[lid].size() : 0);
int event_idx = 0;
// collected sampling data
std::unordered_map<sampler_association_handle,
std::unordered_map<cell_member_type,
std::vector<sample_record>>> sampled;
// samples to process
std::size_t n_values = 0;
std::vector<std::pair<time_type, sampler_association_handle>> samples;
{
std::lock_guard<std::mutex> guard(sampler_mex_);
for (auto& [hdl, assoc]: samplers_) {
// Construct sampling times
const auto& times = util::make_range(assoc.sched.events(t, tfinal));
const auto n_times = times.size();
// Count up the samplers touching _our_ gid
int delta = 0;
for (const auto& pid: assoc.probeset_ids) {
if (pid.gid != gid) continue;
arb_assert (0 == sampled[hdl].count(pid));
sampled[hdl][pid].reserve(n_times);
delta += n_times;
}
if (delta == 0) continue;
n_values += delta;
// only exact sampling: ignore lax and never look at policy
for (auto t: times) samples.emplace_back(t, hdl);
}
}
std::sort(samples.begin(), samples.end());
int n_samples = samples.size();
int sample_idx = 0;
// Now allocate some scratch space for the probed values, if we don't,
// re-alloc might move our data
std::vector<value_type> sampled_voltages;
sampled_voltages.reserve(n_values);
// integrate until tfinal using the exact solution of membrane voltage differential equation.
for (;;) {
const auto event_time = event_idx < n_events ? event_lanes[lid][event_idx].time : tfinal;
const auto sample_time = sample_idx < n_samples ? samples[sample_idx].first : tfinal;
const auto time = std::min(event_time, sample_time);
// bail at end of integration interval
if (time >= tfinal) break;
// Check what to do, we might need to process events **and/or** perform
// sampling.
// NB. we put events before samples, if they collide we'll see
// the update in sampling.

bool do_event = time == event_time;
bool do_sample = time == sample_time;

if (do_event) {
const auto& event_lane = event_lanes[lid];
// process all events at time t
auto weight = 0.0;
for (; event_idx < n_events && event_lane[event_idx].time <= time; ++event_idx) {
weight += event_lane[event_idx].weight;
}
// skip event if neuron is in refactory period
if (time >= t) {
// Let the membrane potential decay.
cell.V_m *= exp((t - time) / cell.tau_m);
// Add jump due to spike(s).
cell.V_m += weight / cell.C_m;
// Update current time
t = time;
// If crossing threshold occurred
if (cell.V_m >= cell.V_th) {
// save spike
spikes_.push_back({{gid, 0}, time});
// Advance to account for the refractory period.
// This means decay will also start at t + t_ref
t += cell.t_ref;
// Reset the voltage to resting potential.
cell.V_m = cell.E_L;
}
}
}

// Let the membrane potential decay.
auto decay = exp(-(time - t) / cell.tau_m);
cell.V_m *= decay;
auto update = weight / cell.C_m;
// Add jump due to spike.
cell.V_m += update;
t = time;
// If crossing threshold occurred
if (cell.V_m >= cell.V_th) {
cell_member_type spike_neuron_gid = {gids_[lid], 0};
spike s = {spike_neuron_gid, t};
spikes_.push_back(s);

// Advance the last_time_updated to account for the refractory period.
t += cell.t_ref;

// Reset the voltage to resting potential.
cell.V_m = cell.E_L;
if (do_sample) {
// Consume all sample events at this time
for (; sample_idx < n_samples && samples[sample_idx].first <= time; ++sample_idx) {
const auto& [s_time, hdl] = samples[sample_idx];
for (const auto& key: samplers_[hdl].probeset_ids) {
const auto& kind = probes_.at(key).kind;
// This is the only thing we know how to do: Probing U(t)
switch (kind) {
case lif_probe_kind::voltage: {
// Compute, but do not _set_ V_m
auto U = cell.V_m;
if (time >= t) U *= exp((t - time) / cell.tau_m);
// Store U for later use.
sampled_voltages.push_back(U);
// Set up reference to sampled value
sampled[hdl][key].push_back(sample_record{time, {&sampled_voltages.back()}});
break;
}
default:
throw arbor_internal_error{"Invalid LIF probe kind"};
}
}
}
}
if (!(do_sample || do_event)) {
throw arbor_internal_error{"LIF cell group: Must select either sample or spike event; got neither."};
}
last_time_updated_[lid] = t;
}
arb_assert (sampled_voltages.size() == n_values);
// Now we need to call all sampler callbacks with the data we have collected
{
std::lock_guard<std::mutex> guard(sampler_mex_);
for (const auto& [k, vs]: sampled) {
const auto& fun = samplers_[k].sampler;
for (const auto& [id, us]: vs) {
auto meta = get_probe_metadata(id)[0];
fun(meta, us.size(), us.data());
}
}
}
}

// This is the last time a cell was updated.
last_time_updated_[lid] = t;
std::vector<probe_metadata> lif_cell_group::get_probe_metadata(cell_member_type key) const {
if (probes_.count(key)) {
return {probe_metadata{key, {}, 0, {&probes_.at(key).metadata}}};
} else {
return {};
}
}
25 changes: 24 additions & 1 deletion arbor/lif_cell_group.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <vector>
#include <mutex>

#include <arbor/export.hpp>
#include <arbor/common_types.hpp>
Expand All @@ -9,6 +10,7 @@
#include <arbor/sampling.hpp>
#include <arbor/spike.hpp>

#include "sampler_map.hpp"
#include "cell_group.hpp"
#include "label_resolution.hpp"

Expand Down Expand Up @@ -37,10 +39,21 @@ class ARB_ARBOR_API lif_cell_group: public cell_group {
virtual void remove_sampler(sampler_association_handle) override;
virtual void remove_all_samplers() override;

virtual std::vector<probe_metadata> get_probe_metadata(cell_member_type) const override;

private:
enum class lif_probe_kind { voltage };

struct lif_probe_info {
probe_tag tag;
lif_probe_kind kind;
lif_probe_metadata metadata;
};


// Advances a single cell (lid) with the exact solution (jumps can be arbitrary).
// Parameter dt is ignored, since we make jumps between two consecutive spikes.
void advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, pse_vector& event_lane);
void advance_cell(time_type tfinal, time_type dt, cell_gid_type lid, const event_lane_subrange& event_lane);

// List of the gids of the cells in the group.
std::vector<cell_gid_type> gids_;
Expand All @@ -53,6 +66,16 @@ class ARB_ARBOR_API lif_cell_group: public cell_group {

// Time when the cell was last updated.
std::vector<time_type> last_time_updated_;
// Time when the cell can _next_ be updated;
std::vector<time_type> next_time_updatable_;

// SAFETY: We need to access samplers_ through a mutex since
// simulation::add_sampler might be called concurrently.
std::mutex sampler_mex_;
sampler_association_map samplers_;

// LIF probe metadata, precalculated to pass to callbacks
std::unordered_map<cell_member_type, lif_probe_info> probes_;
};

} // namespace arb
1 change: 0 additions & 1 deletion doc/concepts/cable_cell.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ Once constructed, the cable cell can be queried for specific information about t
labels
mechanisms
decor
probe_sample

API
---
Expand Down
2 changes: 2 additions & 0 deletions doc/concepts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ of the model over the locally available computational resources.

In order to visualize the result of detected spikes a spike recorder can be used, and to analyse Arbor's performance a
meter manager is available.

:ref:`probesample` shows how to extract data from simulations.
Loading

0 comments on commit 3106ff7

Please sign in to comment.