
How to accelerate deep spiking neural networks with the recurrent API from Keras/TensorFlow

by Lynn K.A. Sörensen, Bojian Yin, H. Steven Scholte & Sander M. Bohté

Background

Spiking deep neural networks are attractive for deep learning because of their biological realism and sparse activations, which make them more efficient in how they represent activity. Recently, it has become possible to convert common deep convolutional neural networks, trained with a specific activation function in place of a ReLU, into deep spiking networks after training (e.g. Zambrano et al., 2018; Rueckauer et al., 2016). Developments like this fuel the surge of new applications using spiking deep neural networks for artificial vision and event-based processing (e.g. Mueggler, Huber & Scaramuzza, 2014).

While there is a range of APIs that excel at capturing the complexity of spiking neurons, such as GeNN, NEST or Nengo, these frameworks are not yet integrated with the most commonly used deep learning APIs. Here, we show how you can implement any spiking neuron in a deep neural network by leveraging the recurrent layer routines of established deep learning APIs such as Keras/TensorFlow to effectively speed up computation time.


Approach

Neuron models

A spiking neuron maintains a number of parameters over time, such as the integrated current and, in the case of an adaptive spiking neuron, its firing threshold. Because all of these parameters decay over time as well as respond to new incoming activations, their states have to be carried over from one timestep to the next. Algorithmically, this results in a for-loop in which every timestep is computed sequentially (but see causal convolutions). These kinds of computations are not well-suited for GPUs since they cannot be parallelized and require a lot of memory.

Below is an example implementation of such a spiking neuron in NumPy.

import numpy as np


class ASN:
    """Adaptive spiking neuron class."""

    def __init__(self, mf=0.1, bias=0):
        # Params of a spiking neuron
        # membrane filter
        self.tau_phi = 2.5
        self.dPhi = np.exp(-1 / self.tau_phi)
        # threshold decay filter
        self.tau_gamma = 15.0
        self.dGamma = np.exp(-1 / self.tau_gamma)
        # refractory decay filter
        self.tau_eta = 50.0
        self.dEta = np.exp(-1 / self.tau_eta)
        self.dBeta = self.dEta

        self.m_f = mf
        self.theta0 = self.m_f  # Resting threshold

        self.S_bias = bias
        self.S = self.S_bias  # filtered activation, initialized with bias
        self.S_dyn = 0
        self.theta = self.theta0  # Start value of the threshold
        self.theta_dyn = 0  # dynamic part of the threshold
        self.S_hat = 0  # refractory response, internal approximation

        self.I = 0
        self.spike = 0

    def update(self, current, spike_train=True):
        """Inject current for one moment in time."""
        # Membrane filter
        if spike_train:
            self.I = self.I * self.dBeta + current
        else:
            self.I = current
        self.S_dyn = (1 - self.dPhi) * self.I + self.dPhi * self.S_dyn
        self.S = self.S_bias + self.S_dyn
        # Decay of the refractory response
        self.S_hat = self.S_hat * self.dEta

        # Spike?
        if self.S - self.S_hat > 0.5 * self.theta:
            self.spike = 1  # Code spike

            # Update refractory response
            self.S_hat = self.S_hat + self.theta

            # Update threshold
            self.theta_dyn = self.theta_dyn + self.m_f * self.theta  # adaptive part based on the paper
        else:
            self.spike = 0

        # Decay of the dynamic threshold
        self.theta_dyn = self.theta_dyn * self.dGamma
        self.theta = self.theta0 + self.theta_dyn

    def call(self, input, spike_train=True, mf=0.1, bias=0):
        """Run the neuron over an input of shape (batch_size, timesteps, 1)."""
        timesteps = input.shape[1]
        batch_size = input.shape[0]
        spikes = np.zeros(input.shape)

        for b in range(batch_size):
            self.__init__(mf=mf, bias=bias)  # reset the neuron state for every sample
            for t in range(timesteps):  # loop over timesteps
                self.update(input[b, t, :], spike_train=spike_train)
                spikes[b, t, 0] = self.spike

        return spikes
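
As a quick illustration of how the class above could be driven, the sketch below injects a constant current into a single neuron for 300 timesteps. The input values and shapes are chosen purely for illustration.

import numpy as np

inp = 0.5 * np.ones((1, 300, 1))  # (batch, timesteps, 1): a constant input current

neuron = ASN(mf=0.1, bias=0)
spikes = neuron.call(inp, spike_train=False)  # spike_train=False: the input is a current

print(spikes.shape)       # (1, 300, 1)
print(int(spikes.sum()))  # number of spikes emitted over the 300 timesteps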

Recurrent API

Recurrent layer processing routines found in deep learning toolboxes such as Keras/TensorFlow and PyTorch offer a more optimized way to perform these computations on a GPU. The key is that only the preceding state of the parameters has to be kept in GPU memory, which makes the operation more lightweight, and that by expressing the spiking computations as a recurrent routine, they can be optimized for GPU computation.
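
To isolate the pattern, the sketch below shows how keras.backend.rnn drives a step function over time, using a plain leaky integrator with a single state rather than a spiking neuron; the spiking layer that follows uses the same structure with four states.

import tensorflow as tf
from tensorflow.keras import backend as K


def leaky_step(x_t, states):
    """One timestep of a leaky integrator: decay the previous state, add the input."""
    s_prev = states[0]
    s = 0.9 * s_prev + 0.1 * x_t
    return s, [s]  # (output at this timestep, updated state list)


x = tf.random.uniform((2, 300, 5))  # (batch, timesteps, features)
s0 = tf.zeros((2, 5))               # initial state, one value per feature
last_output, outputs, states = K.rnn(leaky_step, x, [s0])
print(outputs.shape)                # (2, 300, 5): the full state trajectory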

Below is the call function of a spiking neuron layer in TensorFlow/Keras.

    def call(self, inputs, mask=None):

        batch_size = K.shape(inputs)[0]

        # Preallocate states
        I = tf.zeros((batch_size, self.units))
        S_dyn = tf.zeros((batch_size, self.units))  # dynamic part of the activation
        theta_dyn = tf.zeros((batch_size, self.units))  # dynamic part of the threshold
        S_hat = tf.zeros((batch_size, self.units))  # refractory response, internal approximation

        # Loop over all time points
        last_output, outputs, states = K.rnn(self.update,
                                             inputs,
                                             [I, theta_dyn, S_dyn, S_hat],
                                             unroll=False,
                                             input_length=K.int_shape(inputs)[1])

        return outputs
        

The core of the spiking neuron can be found in the step function (here called update), which is passed to keras.backend.rnn().

    def update(self, current, states):
        """inject current for one moment in time at once"""
        # states: [I, theta_dyn, S_dyn, S_hat]

        I = states[0]
        theta_dyn = states[1]
        S_dyn = states[2]
        S_hat = states[3]

        theta = self.theta0 + theta_dyn

        # Apply dense weights
        current = tf.matmul(current, self.kernel)

        # Integrate the incoming current
        if self.input_layer:  # the input to the neuron is already a current, e.g. pixel values
            I = current
        else:  # the input is a spike train
            I = I * self.dBeta + current

        # Membrane filter
        S_dyn = (1 - self.dPhi) * I + self.dPhi * S_dyn

        if self.use_bias:
            S = self.bias + S_dyn
        else:
            S = S_dyn

        # Decay
        S_hat = S_hat * self.dEta

        # Spike?
        spike = tf.cast(S - S_hat > 0.5 * theta, tf.float32)  # Code spike

        # Update refractory response
        S_hat = S_hat + tf.multiply(theta, spike)

        # Update threshold
        theta_dyn = theta_dyn + tf.multiply(tf.multiply(theta, spike), self.mf)

        # Decay
        theta_dyn = theta_dyn * self.dGamma

        if self.last_layer:
            out = self.activation(S * self.h)  # for the last layer give out the S instead of spikes
        else:
            out = spike * self.h  # if it is a spike scale by h

        return out, [I, theta_dyn, S_dyn, S_hat] 
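
The two methods above assume a surrounding custom Keras layer that defines the units, weights, decay constants and flags they refer to (self.units, self.kernel, self.dPhi, self.input_layer, and so on). Below is a minimal sketch of what such a layer class could look like; the class name SpikingDense, the constructor arguments and the default values are illustrative assumptions and not the repository's exact code.

import numpy as np
import tensorflow as tf
from tensorflow.keras import activations
from tensorflow.keras.layers import Layer


class SpikingDense(Layer):
    """Hypothetical wrapper around the call()/update() methods shown above."""

    def __init__(self, units, mf=0.1, h=1.0, input_layer=False, last_layer=False,
                 use_bias=True, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.mf = mf                    # adaptation strength
        self.h = h                      # output scaling of a spike
        self.theta0 = mf                # resting threshold
        self.input_layer = input_layer  # True if the input is a current (e.g. pixel values)
        self.last_layer = last_layer    # True if the layer should read out S instead of spikes
        self.use_bias = use_bias
        self.activation = activations.get(activation)
        # Decay factors of the membrane, threshold and refractory filters
        self.dPhi = float(np.exp(-1 / 2.5))
        self.dGamma = float(np.exp(-1 / 15.0))
        self.dEta = float(np.exp(-1 / 50.0))
        self.dBeta = self.dEta

    def build(self, input_shape):
        # Dense weights applied to the incoming spikes/currents at every timestep
        self.kernel = self.add_weight(name='kernel',
                                      shape=(int(input_shape[-1]), self.units),
                                      initializer='glorot_uniform')
        if self.use_bias:
            self.bias = self.add_weight(name='bias',
                                        shape=(self.units,),
                                        initializer='zeros')
        super().build(input_shape)

With call() and update() added as methods of this class, the layer can be stacked like any other Keras layer; the shapes below are arbitrary:

x = np.random.rand(4, 300, 10).astype('float32')  # (batch, timesteps, features)
hidden = SpikingDense(128, input_layer=True)      # input currents in, spikes out
readout = SpikingDense(1, last_layer=True, activation='sigmoid')  # reads out S
y = readout(hidden(x))
print(y.shape)                                    # (4, 300, 1)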

Comparing these two approaches over 300 timesteps shows that the TensorFlow implementation scales well with an increasing number of neurons, while the NumPy implementation quickly incurs long delays.

[Figure: computation time of the NumPy and TensorFlow/Keras implementations for an increasing number of neurons over 300 timesteps]

Please see the Demo for more details.

Conclusion

Here, we showed how any spiking neuron model can be implemented with the recurrent network routines of current deep learning APIs. This is a simple and straightforward way to implement spiking neurons that scales to large architectures such as ResNet18.
