From 3ffe7058afbd2fac186c3eb81095c99a18a6494c Mon Sep 17 00:00:00 2001
From: timcheck
Date: Mon, 8 Jul 2024 12:58:12 -0700
Subject: [PATCH] Recurrent netx save and load (#324)

* tutorial on netx save/load for recurrent network

Signed-off-by: Jonathan Timcheck

* replace the new step_delay function with delay(x, 1), which worked
  previously for recurrent nets

Signed-off-by: Jonathan Timcheck

* netx save and load for RecurrentDense

Signed-off-by: Jonathan Timcheck

* removing recurrent_example.net

Signed-off-by: Jonathan Timcheck

* unit test recurrent NetX load/save

Signed-off-by: Jonathan Timcheck

* lint

Signed-off-by: Jonathan Timcheck

* lint

---------

Signed-off-by: Jonathan Timcheck
Co-authored-by: Marcus G K Williams <168222+mgkwill@users.noreply.github.com>
---
 src/lava/lib/dl/netx/blocks/process.py | 106 +++++++++++++++++++++++
 src/lava/lib/dl/netx/hdf5.py           |  41 +++++++--
 src/lava/lib/dl/netx/utils.py          |   3 +-
 src/lava/lib/dl/slayer/block/base.py   |  69 +++++++++++++--
 tests/lava/lib/dl/netx/test_hdf5.py    | 112 ++++++++++++++++++++++++-
 5 files changed, 313 insertions(+), 18 deletions(-)

diff --git a/src/lava/lib/dl/netx/blocks/process.py b/src/lava/lib/dl/netx/blocks/process.py
index 8a80a6b99..3eaece6e5 100644
--- a/src/lava/lib/dl/netx/blocks/process.py
+++ b/src/lava/lib/dl/netx/blocks/process.py
@@ -198,6 +198,112 @@ def export_hdf5(self, handle: Union[h5py.File, h5py.Group]) -> None:
         raise NotImplementedError
 
 
+class RecurrentDense(AbstractBlock):
+    """RecurrentDense layer block.
+
+    Parameters
+    ----------
+    shape : tuple or list
+        shape of the layer block in (x, y, z)/WHC format.
+    neuron_params : dict, optional
+        dictionary of neuron parameters. Defaults to None.
+    weight : np.ndarray
+        synaptic weight.
+    weight_rec : np.ndarray
+        recurrent synaptic weight.
+    delay : np.ndarray, optional
+        synaptic delay. Defaults to None.
+    bias : np.ndarray or None
+        bias of neuron. None means no bias. Defaults to None.
+    has_graded_input : bool
+        flag for graded spikes at input. Defaults to False.
+    num_weight_bits : int
+        number of weight bits. Defaults to 8.
+    weight_exponent : int
+        weight exponent value. Defaults to 0.
+    sparse_synapse : bool
+        flag indicating whether the connection is sparse. Defaults to False.
+    input_message_bits : int, optional
+        number of message bits in input spike. Defaults to 0 meaning unary
+        spike.
+    """
+
+    def __init__(self, **kwargs: Union[dict, tuple, list, int, bool]) -> None:
+        super().__init__(**kwargs)
+
+        weight = kwargs.pop('weight')
+        weight_rec = kwargs.pop('weight_rec')
+        delay = kwargs.pop('delay', None)
+        num_weight_bits = kwargs.pop('num_weight_bits', 8)
+        weight_exponent = kwargs.pop('weight_exponent', 0)
+        sparse_synapse = kwargs.pop('sparse_synapse', False)
+
+        if delay is None:
+            if sparse_synapse:
+                Synapse = SparseSynapse
+                weight = csr_matrix(weight)
+            else:
+                Synapse = DenseSynapse
+
+            self.synapse = Synapse(
+                weights=weight,
+                weight_exp=weight_exponent,
+                num_weight_bits=num_weight_bits,
+                num_message_bits=self.input_message_bits,
+                shape=weight.shape
+            )
+            self.synapse_rec = Synapse(
+                weights=weight_rec,
+                weight_exp=weight_exponent,
+                num_weight_bits=num_weight_bits,
+                num_message_bits=self.input_message_bits,
+                shape=weight_rec.shape
+            )
+        else:
+            # TODO: test this in greater detail
+            if sparse_synapse:
+                Synapse = DelaySparseSynapse
+                delay[weight == 0] = 0
+                weight = csr_matrix(weight)
+                delay = csr_matrix(delay)
+            else:
+                Synapse = DelayDenseSynapse
+
+            self.synapse = Synapse(
+                weights=weight,
+                delays=delay.astype(int),
+                max_delay=62,
+                num_weight_bits=num_weight_bits,
+                num_message_bits=self.input_message_bits,
+            )
+            self.synapse_rec = Synapse(
+                weights=weight_rec,
+                delays=delay.astype(int),
+                max_delay=62,
+                num_weight_bits=num_weight_bits,
+                num_message_bits=self.input_message_bits,
+            )
+
+        if self.shape != self.synapse.a_out.shape:
+            raise RuntimeError(
+                f'Expected synapse output shape to be {self.shape}, '
+                f'found {self.synapse.a_out.shape}.'
+            )
+
+        self.neuron = self._neuron(kwargs.pop('bias', None))
+
+        self.inp = InPort(shape=self.synapse.s_in.shape)
+        self.out = OutPort(shape=self.neuron.s_out.shape)
+        self.inp.connect(self.synapse.s_in)
+        self.synapse.a_out.connect(self.neuron.a_in)
+        self.neuron.s_out.connect(self.out)
+
+        # close the recurrent loop: the block's own output spikes drive the
+        # recurrent synapse, which feeds back into the same neuron population
+        self.neuron.s_out.connect(self.synapse_rec.s_in)
+        self.synapse_rec.a_out.connect(self.neuron.a_in)
+
+        self._clean()
+
+
 class ComplexDense(AbstractBlock):
     """Dense Complex layer block.
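For intuition, the block above amounts to a simple dataflow: the feedforward synapse drives the neurons from the input spikes, while the recurrent synapse drives the same neurons from the block's own previous output spikes. Below is a minimal pure-NumPy sketch of that loop, assuming simplified CUBA-style leak/fire/reset dynamics (illustrative only; the real block executes as Lava synapse and neuron processes):

# Sketch of the RecurrentDense dataflow under assumed simplified dynamics.
import numpy as np

np.random.seed(0)
n_in, n_out, num_steps = 6, 5, 20
w = 0.1 * np.random.randn(n_out, n_in)        # plays the role of `weight`
w_rec = 0.05 * np.random.randn(n_out, n_out)  # plays the role of `weight_rec`

voltage = np.zeros(n_out)
s_out = np.zeros(n_out)  # previous-step output spikes
for _ in range(num_steps):
    s_in = (np.random.rand(n_in) < 0.2).astype(float)  # toy input spikes
    # both synapse.a_out and synapse_rec.a_out accumulate into neuron.a_in
    a_in = w @ s_in + w_rec @ s_out
    voltage = (1 - 0.15) * voltage + a_in   # leaky integration
    s_out = (voltage > 0.95).astype(float)  # fire on threshold crossing
    voltage[s_out > 0] = 0.0                # reset spiking neurons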
diff --git a/src/lava/lib/dl/netx/hdf5.py b/src/lava/lib/dl/netx/hdf5.py
index f3e57ea66..bd000f318 100644
--- a/src/lava/lib/dl/netx/hdf5.py
+++ b/src/lava/lib/dl/netx/hdf5.py
@@ -19,7 +19,8 @@
 from lava.lib.dl.slayer.neuron.rf import neuron_params as get_rf_params
 from lava.lib.dl.netx.utils import NetDict
 from lava.lib.dl.netx.utils import optimize_weight_bits
-from lava.lib.dl.netx.blocks.process import Input, Dense, Conv, ComplexDense
+from lava.lib.dl.netx.blocks.process import Input, Dense, RecurrentDense, \
+    Conv, ComplexDense
 from lava.lib.dl.netx.blocks.models import AbstractPyBlockModel
 
 
@@ -316,7 +317,8 @@ def create_dense(layer_config: h5py.Group,
                      reset_interval: Optional[int] = None,
                      reset_offset: int = 0,
                      spike_exp: int = 6,
-                     sparse_synapse: bool = 0) -> Tuple[Dense, str]:
+                     sparse_synapse: bool = False,
+                     rec: bool = False) -> Tuple[Dense, str]:
         """Creates dense layer from layer configuration
 
         Parameters
@@ -389,11 +391,24 @@ def create_dense(layer_config: h5py.Group,
         else:
             weight = layer_config['weight']
+            if rec:
+                weight_rec = layer_config['weight_rec']
 
         if weight.ndim == 1:
             weight = weight.reshape(shape[0], -1)
-
-        opt_weights = optimize_weight_bits(weight)
-        weight, num_weight_bits, weight_exponent, sign_mode = opt_weights
+            if rec:
+                weight_rec = weight_rec.reshape(shape[0], -1)
+
+        if rec:
+            weight_concat = np.hstack((weight, weight_rec))
+            opt_weight_concat = optimize_weight_bits(weight_concat)
+            weight_concat, num_weight_bits, \
+                weight_exponent, sign_mode = opt_weight_concat
+            weight = weight_concat[:, :weight.shape[1]]
+            weight_rec = weight_concat[:, weight.shape[1]:]
+        else:
+            opt_weights = optimize_weight_bits(weight)
+            weight, num_weight_bits, \
+                weight_exponent, sign_mode = opt_weights
 
         # arguments for dense block
         params = {'shape': shape,
@@ -404,6 +419,8 @@ def create_dense(layer_config: h5py.Group,
                   'sign_mode': sign_mode,
                   'input_message_bits': input_message_bits,
                   "sparse_synapse": sparse_synapse}
+        if rec:
+            params['weight_rec'] = weight_rec
 
         if 'delay' in layer_config.keys():
             delay = layer_config['delay']
@@ -434,7 +451,10 @@ def create_dense(layer_config: h5py.Group,
         if 'bias' in layer_config.keys():
             params['bias'] = layer_config['bias']
 
-        proc = Dense(**params)
+        if rec:
+            proc = RecurrentDense(**params)
+        else:
+            proc = Dense(**params)
         table_entry = Network._table_str(type_str='Dense', width=1, height=1,
                                          channel=shape[0],
                                          delay='delay' in layer_config.keys())
@@ -615,14 +635,19 @@ def _create(self) -> List[AbstractProcess]:
                     flatten_next = True
                     table = None
 
-                elif layer_type == 'dense':
+                elif layer_type == 'dense' or layer_type == 'dense_rec':
+                    rec = layer_type == 'dense_rec'
                     layer, table = self.create_dense(
                         layer_config=layer_config[i],
                         input_message_bits=input_message_bits,
                         reset_interval=reset_interval,
                         reset_offset=reset_offset,
                         spike_exp=self.spike_exp,
-                        sparse_synapse=self.sparse_fc_layer)
+                        sparse_synapse=self.sparse_fc_layer,
+                        rec=rec)
                     if i >= self.skip_layers:
                         layers.append(layer)
                     reset_offset += 1
diff --git a/src/lava/lib/dl/netx/utils.py b/src/lava/lib/dl/netx/utils.py
index 7f5219664..4502e29aa 100644
--- a/src/lava/lib/dl/netx/utils.py
+++ b/src/lava/lib/dl/netx/utils.py
@@ -47,7 +47,8 @@ def __init__(
             'iDecay', 'refDelay', 'scaleRho', 'tauRho', 'theta',
             'vDecay', 'vThMant', 'wgtExp', 'sinDecay', 'cosDecay'
         ]
-        self.copy_keys = ['weight', 'bias', 'weight/real', 'weight/imag']
+        self.copy_keys = ['weight', 'weight_rec',
+                          'bias', 'weight/real', 'weight/imag']
 
     def keys(self) -> h5py._hl.base.KeysViewHDF5:
         return self.f.keys()
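The np.hstack in create_dense above exists because weight and weight_rec feed the same neuron population, so they must share a single num_weight_bits / weight_exponent fixed-point scaling; optimizing them jointly picks a scale that covers their combined dynamic range. A toy stand-in for that joint quantization (assumed power-of-two scaling to signed 8-bit; not the actual optimize_weight_bits implementation):

# Toy stand-in for joint fixed-point scaling of feedforward and
# recurrent weights. Assumed behavior, for illustration only.
import numpy as np

def toy_quantize(w, num_bits=8):
    max_mant = 2 ** (num_bits - 1) - 1
    exp = 0
    while np.abs(w).max() * 2 ** (exp + 1) <= max_mant:  # largest safe scale
        exp += 1
    return np.round(w * 2 ** exp).astype(int), -exp  # mantissas, weight_exp

np.random.seed(0)
weight = np.random.randn(5, 6)            # feedforward weights
weight_rec = 0.1 * np.random.randn(5, 5)  # recurrent weights, smaller range

concat = np.hstack((weight, weight_rec))
mant, weight_exp = toy_quantize(concat)   # one shared exponent for both
weight_q = mant[:, :weight.shape[1]]      # split back apart, as in the diff
weight_rec_q = mant[:, weight.shape[1]:]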
diff --git a/src/lava/lib/dl/slayer/block/base.py b/src/lava/lib/dl/slayer/block/base.py
index 546f16007..2e71ffe67 100644
--- a/src/lava/lib/dl/slayer/block/base.py
+++ b/src/lava/lib/dl/slayer/block/base.py
@@ -1478,7 +1478,7 @@ def forward(self, x):
             self.spike_state = spike.clone().detach().reshape(z.shape[:-1])
 
         if self.delay_shift is True:
-            x = step_delay(self, x)
+            x = delay(x, 1)
 
         if self.delay is not None:
             x = self.delay(x)
@@ -1494,11 +1494,64 @@ def shape(self):
         return self.neuron.shape
 
     def export_hdf5(self, handle):
-        """Hdf5 export method for the block.
+        def weight(s):
+            return s.pre_hook_fx(
+                s.weight, descale=True
+            ).reshape(s.weight.shape[:2]).cpu().data.numpy()
 
-        Parameters
-        ----------
-        handle : file handle
-            hdf5 handle to export block description.
-        """
-        pass
+        def delay(d):
+            return torch.floor(d.delay).flatten().cpu().data.numpy()
+
+        handle.create_dataset(
+            'type', (1, ), 'S10', ['dense_rec'.encode('ascii', 'ignore')]
+        )
+
+        handle.create_dataset('shape', data=np.array(self.neuron.shape))
+        handle.create_dataset(
+            'inFeatures', data=self.input_synapse.in_channels)
+        handle.create_dataset(
+            'outFeatures', data=self.input_synapse.out_channels)
+
+        if self.input_synapse.weight_norm_enabled:
+            self.input_synapse.disable_weight_norm()
+
+        if hasattr(self.input_synapse, 'imag'):  # complex synapse
+            handle.create_dataset(
+                'weight/real',
+                data=weight(self.input_synapse.real)
+            )
+            handle.create_dataset(
+                'weight/imag',
+                data=weight(self.input_synapse.imag)
+            )
+            raise NotImplementedError('Complex recurrent not implemented.')
+        else:
+            handle.create_dataset('weight', data=weight(self.input_synapse))
+            handle.create_dataset(
+                'weight_rec', data=weight(self.recurrent_synapse))
+
+        # bias
+        has_norm = False
+        if hasattr(self.neuron, 'norm'):
+            if self.neuron.norm is not None:
+                has_norm = True
+        if has_norm is True:
+            handle.create_dataset(
+                'bias',
+                data=self.neuron.norm.bias.cpu().data.numpy().flatten()
+            )
+
+        # delay
+        if self.delay is not None:
+            self.delay.clamp()  # clamp the delay value
+            handle.create_dataset('delay', data=delay(self.delay))
+
+        # neuron
+        for key, value in self.neuron.device_params.items():
+            handle.create_dataset(f'neuron/{key}', data=value)
+        if has_norm is True:
+            if hasattr(self.neuron.norm, 'weight_exp'):
+                handle.create_dataset(
+                    'neuron/weight_exp',
+                    data=self.neuron.norm.weight_exp
+                )
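The datasets written by the export_hdf5 above can be inspected directly with h5py. A sketch, assuming a file named recurrent_example.net was exported as in the unit test below:

# Sketch: inspect the HDF5 layout written for 'dense_rec' layers
# (assumes recurrent_example.net exists, as produced by the test below).
import h5py

with h5py.File('recurrent_example.net', 'r') as f:
    for name, layer in f['layer'].items():
        print(name, layer['type'][0].decode())          # e.g. 'dense_rec'
        print('  weight     :', layer['weight'].shape)
        print('  weight_rec :', layer['weight_rec'].shape)
        print('  neuron keys:', list(layer['neuron'].keys()))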
diff --git a/tests/lava/lib/dl/netx/test_hdf5.py b/tests/lava/lib/dl/netx/test_hdf5.py
index 28a6c212a..1b6036ee6 100644
--- a/tests/lava/lib/dl/netx/test_hdf5.py
+++ b/tests/lava/lib/dl/netx/test_hdf5.py
@@ -19,7 +19,11 @@
 from lava.proc.dense.process import Dense, DelayDense
 
 from lava.lib.dl import netx
-
+import torch
+from lava.lib.dl import slayer
+import h5py
+from lava.magma.core.run_configs import Loihi2SimCfg
+from lava.magma.core.run_conditions import RunSteps
 
 verbose = True if (('-v' in sys.argv) or ('--verbose' in sys.argv)) else False
 HAVE_DISPLAY = 'DISPLAY' in os.environ
@@ -287,6 +291,112 @@ def test_sparse_axonal_delay_ntidigits(self) -> None:
             ])
         )
 
+    def test_recurrent_netx_save_and_load(self) -> None:
+        """Tests that a recurrent network can be saved and loaded.
+        Verifies that a toy recurrent network runs in Lava and matches
+        the lava-dl output firing rate within a 2% error margin."""
+
+        class Network(torch.nn.Module):
+            def __init__(self):
+                super(Network, self).__init__()
+
+                self.cuba_params = {
+                    "threshold": 0.95,
+                    "current_decay": 0.5,
+                    "voltage_decay": 0.15,
+                    "tau_grad": 1.0,
+                    "scale_grad": 1.0,
+                    "shared_param": False,
+                    "requires_grad": False,
+                    "graded_spike": False,
+                }
+                self.blocks = torch.nn.ModuleList(
+                    [
+                        slayer.block.cuba.Recurrent(
+                            self.cuba_params,
+                            6,
+                            5,
+                            weight_scale=0.0,
+                            pre_hook_fx=None,
+                        ),
+                        slayer.block.cuba.Recurrent(
+                            self.cuba_params,
+                            5,
+                            5,
+                            weight_scale=0.0,
+                            pre_hook_fx=None,
+                        ),
+                    ]
+                )
+
+                self.blocks[0].input_synapse.weight.data += 0.1
+                self.blocks[0].recurrent_synapse.weight.data += 0.05
+
+                self.blocks[1].input_synapse.weight.data += 0.1
+                self.blocks[1].recurrent_synapse.weight.data += 0.05
+
+            def forward(self, spike):
+                for block in self.blocks:
+                    spike = block(spike)
+                return spike
+
+            def export_hdf5(self, filename):
+                # network export to hdf5 format
+                with h5py.File(filename, "w") as h:
+                    layer = h.create_group("layer")
+                    for i, b in enumerate(self.blocks):
+                        b.export_hdf5(layer.create_group(f"{i}"))
+
+        current_file_directory = os.path.dirname(os.path.abspath(__file__))
+
+        # batch = 1, channels = 6, time = 101
+        num_steps = 101
+        input = torch.zeros((1, 6, num_steps))
+        input += 0.2
+
+        net = Network()
+        lava_dl_output = net(input).detach().numpy()
+
+        filename = os.path.join(current_file_directory,
+                                "recurrent_example.net")
+        net.export_hdf5(filename)
+
+        net_lava = netx.hdf5.Network(filename, input_message_bits=24)
+
+        net_lava_input_scale_factor_exp = 10
+        net_lava_input_scale_factor = 2**net_lava_input_scale_factor_exp
+        net_lava.layers[0].synapse_rec.proc_params._parameters[
+            "weight_exp"
+        ] += net_lava_input_scale_factor_exp
+        net_lava.layers[0].neuron.vth.init *= net_lava_input_scale_factor
+        net_lava.layers[0].neuron.bias_mant.init *= net_lava_input_scale_factor
+
+        input_lava = input[0].numpy() * net_lava_input_scale_factor
+
+        source = io.source.RingBuffer(data=input_lava)
+        sink = io.sink.RingBuffer(
+            shape=net_lava.out.shape, buffer=num_steps + 1
+        )
+        source.s_out.connect(net_lava.inp)
+        net_lava.out.connect(sink.a_in)
+
+        run_config = Loihi2SimCfg(select_tag="fixed_pt")
+        run_condition = RunSteps(num_steps=num_steps)
+        net_lava.run(condition=run_condition, run_cfg=run_config)
+        lava_output = sink.data.get()
+        net_lava.stop()
+
+        eps = 0.02
+        assert (
+            np.abs(
+                (lava_dl_output[0][0].mean() - lava_output[0].mean())
+                / lava_dl_output[0][0].mean()
+            )
+            <= eps
+        ), f"""Mean firing rate mismatch {lava_dl_output[0][0].mean()=};
+        lava mean firing rate {lava_output[0].mean()=}"""
+
 
 if __name__ == '__main__':
     unittest.main()
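A note on the scaling in the test above: the source sends graded (24-bit) input spikes, so the real-valued input is moved into fixed point, and everything that drives the first layer's neurons must live at the same scale. A sketch of the arithmetic (values taken from the test; the rationale is an editorial reading of the code, not an authoritative statement of the library's semantics):

# Sketch of the fixed-point input scaling used in the test.
scale_exp = 10
scale = 2 ** scale_exp            # 1024
x_fixed = 0.2 * scale             # 204.8, what the RingBuffer source sends
# To keep the first layer's dynamics unchanged, vth and bias_mant are
# multiplied by `scale`, and synapse_rec's weight_exp is raised by
# `scale_exp`, because the recurrent input spikes are unary (amplitude 1,
# not 1024); every term driving the neuron then shares one scale.
assert x_fixed / scale == 0.2     # scaling by a power of two is exact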