Commit cb76f4e
cleaning up transpose for resource strategy

drankincms committed Jun 8, 2022
1 parent 03823c4 commit cb76f4e
Showing 3 changed files with 10 additions and 17 deletions.
9 changes: 6 additions & 3 deletions hls4ml/backends/vivado/passes/resource_strategy.py
@@ -1,13 +1,13 @@
 import numpy as np

 from hls4ml.model.optimizer import OptimizerPass
-from hls4ml.model.layers import Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D
+from hls4ml.model.layers import Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D, LSTM, GRU

 class ApplyResourceStrategy(OptimizerPass):
     ''' Transposes the weights to use the dense_resource matrix multiply routine '''
     def match(self, node):

-        node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D))
+        node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU))
         is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource'
         already_transformed = node.get_attr('_weights_transposed', False) == True
@@ -26,9 +26,12 @@ def transform(self, model, node):
         elif isinstance(node, SeparableConv2D):
             node.weights['depthwise'].data = np.transpose(node.weights['depthwise'].data, axes=[3, 0, 1, 2]) #(H,W,C,F) => (F,H,W,C)
             node.weights['pointwise'].data = np.transpose(node.weights['pointwise'].data, axes=[3, 0, 1, 2]) #(H,W,C,F) => (F,H,W,C)
+        elif isinstance(node, (LSTM, GRU)):
+            node.weights['weight'].data = np.transpose(node.weights['weight'].data)
+            node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data)
         else:
             raise Exception('Unexpected layer {} with resource strategy'.format(node.class_name))

         node.set_attr('_weights_transposed', True)

-        return False
+        return False
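
For context on the new recurrent branch: the docstring above says the weights are transposed so the resource strategy can use the dense_resource matrix-multiply routine. The sketch below only illustrates the resulting shape change, assuming the Keras conventions that an LSTM kernel is (n_in, 4 * n_units) and its recurrent kernel is (n_units, 4 * n_units); the variable names are illustrative and not part of this commit.

    import numpy as np

    # Assumed Keras LSTM layouts (4 gates; a GRU would have 3):
    # kernel: (n_in, 4 * n_units), recurrent_kernel: (n_units, 4 * n_units)
    n_in, n_units = 8, 16
    kernel = np.random.rand(n_in, 4 * n_units)
    recurrent_kernel = np.random.rand(n_units, 4 * n_units)

    # The pass transposes each array once, so the gate/unit dimension
    # comes first after the transform.
    assert np.transpose(kernel).shape == (4 * n_units, n_in)
    assert np.transpose(recurrent_kernel).shape == (4 * n_units, n_units)
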
4 changes: 2 additions & 2 deletions hls4ml/backends/vivado/vivado_backend.py
@@ -241,7 +241,7 @@ def init_lstm(self, layer):
         reuse_factor = layer.model.config.get_reuse_factor(layer)
         layer.set_attr('recurrent_reuse_factor', reuse_factor)

-        recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[0 if layer.model.config.is_resource_strategy(layer) else 1])
+        recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[1])
         layer.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)

         index_t = IntegerPrecisionType(width=1, signed=False)
@@ -267,7 +267,7 @@ def init_gru(self, layer):
         reuse_factor = layer.model.config.get_reuse_factor(layer)
         layer.set_attr('recurrent_reuse_factor', reuse_factor)

-        recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[0 if layer.model.config.is_resource_strategy(layer) else 1])
+        recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[1])
         layer.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)

         index_t = IntegerPrecisionType(width=1, signed=False)
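
A note on why the conditional could go: the shape[0]-vs-shape[1] switch dated from when layers.py transposed the recurrent weights during initialization under the resource strategy. With the transpose deferred to the ApplyResourceStrategy pass, the weights are still in their original layout when init_lstm and init_gru run, so axis 1 always gives the bias length. A small sketch under the assumed Keras layout (shapes and names are illustrative):

    import numpy as np

    # Assumed Keras-style LSTM recurrent kernel: (n_units, 4 * n_units).
    n_units = 16
    recurrent_weight = np.random.rand(n_units, 4 * n_units)

    # init_lstm now runs before any transpose, so shape[1] spans the
    # gate dimension and is the recurrent bias length in either strategy.
    recurrent_bias = np.zeros(recurrent_weight.shape[1])
    assert recurrent_bias.shape == (4 * n_units,)
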
14 changes: 2 additions & 12 deletions hls4ml/model/layers.py
@@ -916,15 +916,10 @@ def initialize(self):
         self.add_output_variable(state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t')
         self.add_output_variable(state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t')

-        data = self.model.get_weights_data(self.name, 'kernel')
-        if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
-            data = np.transpose(data)
-        self.add_weights_variable(name='weight', var_name='w{index}', data=data)
+        self.add_weights()
         self.add_bias()

         recurrent_weight = self.model.get_weights_data(self.name, 'recurrent_kernel')
-        if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
-            recurrent_weight = np.transpose(recurrent_weight)
         self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight)

 class GRU(Layer):
@@ -963,15 +958,10 @@ def initialize(self):
         self.add_output_variable(state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t')
         self.add_output_variable(state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t')

-        data = self.model.get_weights_data(self.name, 'kernel')
-        if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
-            data = np.transpose(data)
-        self.add_weights_variable(name='weight', var_name='w{index}', data=data)
+        self.add_weights()
         self.add_bias()

         recurrent_weight = self.model.get_weights_data(self.name, 'recurrent_kernel')
-        if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
-            recurrent_weight = np.transpose(recurrent_weight)
         self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight)

 class GarNet(Layer):
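Taken together, the three files move the strategy-dependent transpose out of layer initialization and into the single backend pass, which marks each node with _weights_transposed so the transform runs at most once. A condensed, self-contained sketch of that pattern (FakeNode and apply_resource_strategy are stand-ins for illustration, not the hls4ml API):

    import numpy as np

    class FakeNode:
        '''Stand-in for a graph node: a weights dict plus attributes.'''
        def __init__(self, weights):
            self.weights = weights   # name -> np.ndarray
            self.attrs = {}

        def get_attr(self, key, default=None):
            return self.attrs.get(key, default)

        def set_attr(self, key, value):
            self.attrs[key] = value

    def apply_resource_strategy(node):
        # Mirrors the pass above for recurrent layers: transpose once,
        # then mark the node so a second invocation is a no-op.
        if not node.get_attr('_weights_transposed', False):
            node.weights['weight'] = np.transpose(node.weights['weight'])
            node.weights['recurrent_weight'] = np.transpose(node.weights['recurrent_weight'])
            node.set_attr('_weights_transposed', True)

    node = FakeNode({'weight': np.zeros((8, 64)), 'recurrent_weight': np.zeros((16, 64))})
    apply_resource_strategy(node)
    apply_resource_strategy(node)   # second call changes nothing
    assert node.weights['weight'].shape == (64, 8)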
