netx save and load for RecurrentDense
Signed-off-by: Jonathan Timcheck <jonathan.timcheck@intel.com>
timcheck committed Jun 12, 2024
1 parent 543a95c commit 96b10a9
Showing 4 changed files with 192 additions and 16 deletions.
104 changes: 104 additions & 0 deletions src/lava/lib/dl/netx/blocks/process.py
@@ -196,7 +196,111 @@ def __init__(self, **kwargs: Union[dict, tuple, list, int, bool]) -> None:

    def export_hdf5(self, handle: Union[h5py.File, h5py.Group]) -> None:
        raise NotImplementedError

class RecurrentDense(AbstractBlock):
    """RecurrentDense layer block.

    Parameters
    ----------
    shape : tuple or list
        shape of the layer block in (x, y, z)/WHC format.
    neuron_params : dict, optional
        dictionary of neuron parameters. Defaults to None.
    weight : np.ndarray
        synaptic weight.
    weight_rec : np.ndarray
        recurrent synaptic weight.
    delay : np.ndarray
        synaptic delay. Defaults to None, meaning no synaptic delay.
    bias : np.ndarray or None
        bias of neuron. None means no bias. Defaults to None.
    has_graded_input : bool
        flag for graded spikes at input. Defaults to False.
    num_weight_bits : int
        number of weight bits. Defaults to 8.
    weight_exponent : int
        weight exponent value. Defaults to 0.
    sparse_synapse : bool
        flag for sparse synaptic connection. Defaults to False.
    input_message_bits : int, optional
        number of message bits in input spike. Defaults to 0 meaning unary
        spike.
    """

    def __init__(self, **kwargs: Union[dict, tuple, list, int, bool]) -> None:
        super().__init__(**kwargs)

        weight = kwargs.pop('weight')
        weight_rec = kwargs.pop('weight_rec')
        delay = kwargs.pop('delay', None)
        num_weight_bits = kwargs.pop('num_weight_bits', 8)
        weight_exponent = kwargs.pop('weight_exponent', 0)
        sparse_synapse = kwargs.pop('sparse_synapse', False)

        if delay is None:
            if sparse_synapse:
                Synapse = SparseSynapse
                weight = csr_matrix(weight)
            else:
                Synapse = DenseSynapse

            self.synapse = Synapse(
                weights=weight,
                weight_exp=weight_exponent,
                num_weight_bits=num_weight_bits,
                num_message_bits=self.input_message_bits,
                shape=weight.shape
            )
            self.synapse_rec = Synapse(
                weights=weight_rec,
                weight_exp=weight_exponent,
                num_weight_bits=num_weight_bits,
                num_message_bits=self.input_message_bits,
                shape=weight_rec.shape
            )
        else:
            # TODO: test the delay path in greater detail.
            if sparse_synapse:
                Synapse = DelaySparseSynapse
                delay[weight == 0] = 0
                weight = csr_matrix(weight)
                delay = csr_matrix(delay)
            else:
                Synapse = DelayDenseSynapse

            self.synapse = Synapse(
                weights=weight,
                delays=delay.astype(int),
                max_delay=62,
                num_weight_bits=num_weight_bits,
                num_message_bits=self.input_message_bits,
            )
            self.synapse_rec = Synapse(
                weights=weight_rec,
                delays=delay.astype(int),
                max_delay=62,
                num_weight_bits=num_weight_bits,
                num_message_bits=self.input_message_bits,
            )

        if self.shape != self.synapse.a_out.shape:
            raise RuntimeError(
                f'Expected synapse output shape to be {self.shape}, '
                f'found {self.synapse.a_out.shape}.'
            )

        self.neuron = self._neuron(kwargs.pop('bias', None))

        # Feedforward path: inp -> synapse -> neuron -> out.
        self.inp = InPort(shape=self.synapse.s_in.shape)
        self.out = OutPort(shape=self.neuron.s_out.shape)
        self.inp.connect(self.synapse.s_in)
        self.synapse.a_out.connect(self.neuron.a_in)
        self.neuron.s_out.connect(self.out)

        # Recurrent path: the block's own output spikes re-enter through
        # the recurrent synapse and add to the neuron's activation input.
        self.neuron.s_out.connect(self.synapse_rec.s_in)
        self.synapse_rec.a_out.connect(self.neuron.a_in)

        self._clean()

class ComplexDense(AbstractBlock):
    """Dense Complex layer block.
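The port wiring above gives the neuron two additive drive sources: the feedforward synapse fed by `inp` and the recurrent synapse fed by the block's own output spikes. Below is a minimal NumPy sketch of that dataflow — illustrative only, not the Lava process API; sizes, seed, and the thresholding stand-in for the neuron dynamics are arbitrary assumptions:

```python
import numpy as np

# Hypothetical sizes for illustration.
n_in, n_out, t_steps = 4, 3, 5
rng = np.random.default_rng(0)

weight = rng.integers(-128, 128, size=(n_out, n_in))       # feedforward
weight_rec = rng.integers(-128, 128, size=(n_out, n_out))  # recurrent

s_in = rng.integers(0, 2, size=(t_steps, n_in))  # unary input spikes
s_out = np.zeros(n_out, dtype=int)               # output of previous step

for t in range(t_steps):
    # synapse.a_out and synapse_rec.a_out both connect to neuron.a_in,
    # so the neuron drive is the sum of both matrix-vector products.
    a_in = weight @ s_in[t] + weight_rec @ s_out
    s_out = (a_in >= 64).astype(int)  # crude stand-in for the neuron
    print(t, s_out)
```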
38 changes: 30 additions & 8 deletions src/lava/lib/dl/netx/hdf5.py
@@ -19,7 +19,7 @@
from lava.lib.dl.slayer.neuron.rf import neuron_params as get_rf_params
from lava.lib.dl.netx.utils import NetDict
from lava.lib.dl.netx.utils import optimize_weight_bits
from lava.lib.dl.netx.blocks.process import Input, Dense, RecurrentDense, Conv, ComplexDense
from lava.lib.dl.netx.blocks.models import AbstractPyBlockModel


@@ -316,7 +316,8 @@ def create_dense(layer_config: h5py.Group,
                     reset_interval: Optional[int] = None,
                     reset_offset: int = 0,
                     spike_exp: int = 6,
                     sparse_synapse: bool = False,
                     rec: bool = False) -> Tuple[Dense, str]:
"""Creates dense layer from layer configuration
Parameters
Expand Down Expand Up @@ -389,11 +390,22 @@ def create_dense(layer_config: h5py.Group,

        else:
            weight = layer_config['weight']
            if rec:
                weight_rec = layer_config['weight_rec']
            if weight.ndim == 1:
                weight = weight.reshape(shape[0], -1)

            if rec:
                weight_rec = weight_rec.reshape(shape[0], -1)
                # Optimize feedforward and recurrent weights together so
                # both synapses share one exponent and bit width.
                weight_concat = np.hstack((weight, weight_rec))
                opt_weight_concat = optimize_weight_bits(weight_concat)
                (weight_concat, num_weight_bits,
                 weight_exponent, sign_mode) = opt_weight_concat
                weight = weight_concat[:, :weight.shape[1]]
                weight_rec = weight_concat[:, weight.shape[1]:]
            else:
                opt_weights = optimize_weight_bits(weight)
                weight, num_weight_bits, weight_exponent, sign_mode = opt_weights

        # arguments for dense block
        params = {'shape': shape,
@@ -404,6 +416,8 @@
                  'sign_mode': sign_mode,
                  'input_message_bits': input_message_bits,
                  'sparse_synapse': sparse_synapse}
        if rec:
            params['weight_rec'] = weight_rec

        if 'delay' in layer_config.keys():
            delay = layer_config['delay']
@@ -434,7 +448,10 @@
        if 'bias' in layer_config.keys():
            params['bias'] = layer_config['bias']

        if rec:
            proc = RecurrentDense(**params)
        else:
            proc = Dense(**params)
        table_entry = Network._table_str(type_str='Dense', width=1, height=1,
                                         channel=shape[0],
                                         delay='delay' in layer_config.keys())
@@ -615,14 +632,19 @@ def _create(self) -> List[AbstractProcess]:
                flatten_next = True
                table = None

            elif layer_type in ('dense', 'dense_rec'):
                rec = layer_type == 'dense_rec'
                layer, table = self.create_dense(
                    layer_config=layer_config[i],
                    input_message_bits=input_message_bits,
                    reset_interval=reset_interval,
                    reset_offset=reset_offset,
                    spike_exp=self.spike_exp,
                    sparse_synapse=self.sparse_fc_layer,
                    rec=rec)
                if i >= self.skip_layers:
                    layers.append(layer)
                reset_offset += 1
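Both synapses of a RecurrentDense block are configured with a single `weight_exponent`/`num_weight_bits` pair, which is why the loader optimizes the horizontally concatenated matrix rather than each matrix separately. A self-contained sketch of the concat-then-split step, using a simplified stand-in for `optimize_weight_bits` (the real implementation lives in `lava.lib.dl.netx.utils`):

```python
import numpy as np

def mock_optimize_weight_bits(w):
    # Simplified stand-in for netx.utils.optimize_weight_bits: pick one
    # power-of-two scale for the whole matrix and quantize to 8 bits.
    w_max = np.abs(w).max()
    weight_exponent = int(np.ceil(np.log2(w_max / 127))) if w_max > 0 else 0
    quantized = np.round(w / 2.0 ** weight_exponent).astype(int)
    return quantized, 8, weight_exponent, 1  # weights, bits, exp, sign_mode

rng = np.random.default_rng(0)
weight = rng.normal(size=(3, 4))      # feedforward, (out, in)
weight_rec = rng.normal(size=(3, 3))  # recurrent, (out, out)

# One pass over the concatenation yields a shared exponent and bit width;
# slicing recovers the two matrices at their original shapes.
concat = np.hstack((weight, weight_rec))
q, num_bits, w_exp, sign_mode = mock_optimize_weight_bits(concat)
weight_q = q[:, :weight.shape[1]]
weight_rec_q = q[:, weight.shape[1]:]
assert weight_q.shape == weight.shape
assert weight_rec_q.shape == weight_rec.shape
```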
2 changes: 1 addition & 1 deletion src/lava/lib/dl/netx/utils.py
@@ -47,7 +47,7 @@ def __init__(
            'iDecay', 'refDelay', 'scaleRho', 'tauRho', 'theta', 'vDecay',
            'vThMant', 'wgtExp', 'sinDecay', 'cosDecay'
        ]
        self.copy_keys = ['weight', 'weight_rec', 'bias',
                          'weight/real', 'weight/imag']

    def keys(self) -> h5py._hl.base.KeysViewHDF5:
        return self.f.keys()
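With `'weight_rec'` in `copy_keys`, a recurrent layer's weights read back from an exported file like any other copied dataset. A hedged h5py sketch — the file name and layer index are placeholders, assuming a file produced by the export code below:

```python
import h5py

# 'network.net' and layer group '1' are assumptions for illustration.
with h5py.File('network.net', 'r') as f:
    layer = f['layer/1']
    if layer['type'][0] == b'dense_rec':
        weight = layer['weight'][()]          # feedforward weights
        weight_rec = layer['weight_rec'][()]  # recurrent weights
        print(weight.shape, weight_rec.shape)
```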
64 changes: 57 additions & 7 deletions src/lava/lib/dl/slayer/block/base.py
@@ -1494,11 +1494,61 @@ def shape(self):
        return self.neuron.shape

    def export_hdf5(self, handle):
        """Hdf5 export method for the block.

        Parameters
        ----------
        handle : file handle
            hdf5 handle to export block description.
        """
        def weight(s):
            return s.pre_hook_fx(
                s.weight, descale=True
            ).reshape(s.weight.shape[:2]).cpu().data.numpy()

        def delay(d):
            return torch.floor(d.delay).flatten().cpu().data.numpy()

        handle.create_dataset(
            'type', (1, ), 'S10', ['dense_rec'.encode('ascii', 'ignore')]
        )

        handle.create_dataset('shape', data=np.array(self.neuron.shape))
        handle.create_dataset('inFeatures',
                              data=self.input_synapse.in_channels)
        handle.create_dataset('outFeatures',
                              data=self.input_synapse.out_channels)

        if self.input_synapse.weight_norm_enabled:
            self.input_synapse.disable_weight_norm()

        if hasattr(self.input_synapse, 'imag'):  # complex synapse
            handle.create_dataset(
                'weight/real',
                data=weight(self.input_synapse.real)
            )
            handle.create_dataset(
                'weight/imag',
                data=weight(self.input_synapse.imag)
            )
            raise NotImplementedError('Complex recurrent not implemented.')
        else:
            handle.create_dataset('weight', data=weight(self.input_synapse))
            handle.create_dataset('weight_rec',
                                  data=weight(self.recurrent_synapse))

        # bias
        has_norm = False
        if hasattr(self.neuron, 'norm'):
            if self.neuron.norm is not None:
                has_norm = True
        if has_norm is True:
            handle.create_dataset(
                'bias',
                data=self.neuron.norm.bias.cpu().data.numpy().flatten()
            )

        # delay
        if self.delay is not None:
            self.delay.clamp()  # clamp the delay value
            handle.create_dataset('delay', data=delay(self.delay))

        # neuron
        for key, value in self.neuron.device_params.items():
            handle.create_dataset(f'neuron/{key}', data=value)
        if has_norm is True:
            if hasattr(self.neuron.norm, 'weight_exp'):
                handle.create_dataset(
                    'neuron/weight_exp',
                    data=self.neuron.norm.weight_exp
                )

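Taken together, the four changes close the save/load loop for recurrent dense layers. A sketch of the round trip — it assumes a trained slayer network object `net` with a `blocks` list containing the recurrent block, and the per-block export loop mirrors the usual lava-dl export pattern rather than quoting this commit:

```python
import h5py
from lava.lib.dl import netx

# Save: each block writes its own group, including the 'dense_rec' type
# tag and the 'weight_rec' dataset produced by export_hdf5 above.
with h5py.File('network.net', 'w') as h:
    layer_group = h.create_group('layer')
    for i, block in enumerate(net.blocks):  # `net` is assumed pre-trained
        block.export_hdf5(layer_group.create_group(f'{i}'))

# Load: netx dispatches 'dense_rec' layers to the RecurrentDense block.
net_lava = netx.hdf5.Network(net_config='network.net')
print(net_lava)
```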