Vivado Backend GRU/LSTM support (PR560) #576

Merged
merged 10 commits on Jun 21, 2022
171 changes: 171 additions & 0 deletions hls4ml/backends/vivado/passes/recurrent_templates.py
@@ -0,0 +1,171 @@

from hls4ml.backends.backend import get_backend
from hls4ml.model.layers import LSTM, GRU
from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate

# recurrent multiplication template

recr_mult_config_template = """struct config{index} : nnet::dense_config {{
static const unsigned n_in = {n_in};
static const unsigned n_out = {n_out};
static const unsigned strategy = nnet::{strategy};
static const unsigned reuse_factor = {reuse};
static const unsigned n_zeros = {nzeros};
static const unsigned n_nonzeros = {nonzeros};
static const bool store_weights_in_bram = false;
typedef {accum_t.name} accum_t;
typedef {bias_t.name} bias_t;
typedef {weight_t.name} weight_t;
typedef ap_{index_t} index_t;
template<class x_T, class y_T, class res_T>
using product = nnet::product::{product_type}<x_T, y_T, res_T>;
}};\n"""

# activation templates

activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
static const unsigned n_in = {n_in};
static const unsigned table_size = {table_size};
static const unsigned io_type = nnet::{iotype};
static const unsigned reuse_factor = {reuse};
typedef ap_{table_t} table_t;
}};\n"""

recr_activ_config_template = """struct {type}_config{index}_recr : nnet::activ_config {{
static const unsigned n_in = {n_in};
static const unsigned table_size = {table_size};
static const unsigned io_type = nnet::{iotype};
static const unsigned reuse_factor = {reuse};
typedef ap_{table_t} table_t;
}};\n"""

# LSTM + GRU templates

recr_config_template = """struct config{index} : nnet::{recr_type}_config {{
typedef {accum_t.name} accum_t;
typedef {weight_t.name} weight_t; // Matrix
typedef {bias_t.name} bias_t; // Vector
typedef {config_mult_t1} mult_config1;
typedef {config_mult_t2} mult_config2;
typedef {recr_act_t} ACT_CONFIG_{RECR_TYPE};
template<class x_T, class y_T, class config_T>
using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
typedef {act_t} ACT_CONFIG_T;
template<class x_T, class y_T, class config_T>
using activation = nnet::activation::{activation}<x_T, y_T, config_T>;
static const unsigned n_in = {n_in};
static const unsigned n_out = {n_out};
static const unsigned n_state = {n_state};
static const unsigned n_sequence = {n_sequence};
static const unsigned n_sequence_out = {n_sequence_out};
static const unsigned io_type = nnet::{strategy};
static const unsigned reuse_factor = {reuse};
static const bool store_weights_in_bram = false;
static const bool use_static = {static};
}};\n"""

recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'

recr_include_list = ['nnet_utils/nnet_recurrent.h']

class RecurrentConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__((LSTM, GRU))
self.template = recr_config_template
self.act_template = activ_config_template
self.recr_act_template = recr_activ_config_template
self.mult1_template = recr_mult_config_template
self.mult2_template = recr_mult_config_template

def format(self, node):

params = self._default_config_params(node)

params['n_in'] = node.get_input_variable().dim_names[1]
params['n_sequence'] = node.get_input_variable().dim_names[0]
if node.get_attr('return_sequences'):
params['n_sequence_out'] = node.get_output_variable().dim_names[0]
params['n_state'] = node.get_output_variable().dim_names[1]
params['n_out'] = node.get_output_variable().dim_names[1]
else:
params['n_sequence_out'] = 1
params['n_state'] = node.get_output_variable().dim_names[0]
params['n_out'] = node.get_output_variable().dim_names[0]
params['config_mult_t1'] = 'config{}_1'.format(node.index)
params['config_mult_t2'] = 'config{}_2'.format(node.index)
params['recr_act_t'] = '{}_config{}_recr'.format(node.get_attr('recurrent_activation'), node.index)
params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
params['strategy'] = node.get_attr('strategy')
params['static'] = 'true' if node.attributes['static'] else 'false'
params['recr_type'] = node.class_name.lower()
params['RECR_TYPE'] = node.class_name

if node.class_name=='LSTM':
n_recr_mult = 4
else: #GRU
n_recr_mult = 3

recr_config = self.template.format(**params)

act_params = self._default_config_params(node)
recr_act_params = self._default_config_params(node)

act_params['type'] = node.get_attr('activation')
recr_act_params['type'] = node.get_attr('recurrent_activation')
if node.get_attr('return_sequences'):
act_params['n_in'] = node.get_output_variable().dim_names[1]
recr_act_params['n_in'] = node.get_output_variable().dim_names[1] + ' * %i'%(n_recr_mult-1)
else:
act_params['n_in'] = node.get_output_variable().dim_names[0]
recr_act_params['n_in'] = node.get_output_variable().dim_names[0] + ' * %i'%(n_recr_mult-1)

act_config = self.act_template.format(**act_params)
recr_act_config = self.recr_act_template.format(**recr_act_params)

mult_params1 = self._default_config_params(node)
mult_params2 = self._default_config_params(node)

mult_params1['n_in'] = node.get_input_variable().dim_names[1]
if node.get_attr('return_sequences'):
mult_params1['n_out'] = node.get_output_variable().dim_names[1] + ' * %i'%n_recr_mult
else:
mult_params1['n_out'] = node.get_output_variable().dim_names[0] + ' * %i'%n_recr_mult
mult_params1['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
mult_params1['reuse'] = params['reuse']
mult_params1['index'] = str(node.index) + '_1'
mult_params1['nzeros'] = node.get_weights('weight').nzeros
mult_params1['nonzeros'] = node.get_weights('weight').nonzeros
if node.get_attr('return_sequences'):
mult_params2['n_in'] = node.get_output_variable().dim_names[1]
mult_params2['n_out'] = node.get_output_variable().dim_names[1] + ' * %i'%n_recr_mult
else:
mult_params2['n_in'] = node.get_output_variable().dim_names[0]
mult_params2['n_out'] = node.get_output_variable().dim_names[0] + ' * %i'%n_recr_mult
mult_params2['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision)
mult_params2['reuse'] = node.attributes['recurrent_reuse_factor']
mult_params2['index'] = str(node.index) + '_2'
mult_params2['nzeros'] = node.get_weights('recurrent_weight').nzeros
mult_params2['nonzeros'] = node.get_weights('recurrent_weight').nonzeros

mult_config1 = self.mult1_template.format(**mult_params1)
mult_config2 = self.mult2_template.format(**mult_params2)

return mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + recr_config

class RecurrentFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__((LSTM, GRU), include_header=recr_include_list)
self.template = recr_function_template

def format(self, node):
params = self._default_function_params(node)
params['w'] = node.get_weights('weight').name
params['b'] = node.get_weights('bias').name
params['wr'] = node.get_weights('recurrent_weight').name
params['br'] = node.get_weights('recurrent_bias').name
params['activation'] = node.get_attr('activation')
params['recurrent_activation'] = node.get_attr('recurrent_activation')
params['recr_type'] = node.class_name.lower()

return self.template.format(**params)
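
For reference, the function-call template above is rendered per layer by substituting the parameter dictionary built in format(); a minimal, self-contained sketch with hypothetical type, variable, and weight names (not taken from a real model) would be:

# Sketch: rendering recr_function_template with hypothetical parameter values.
recr_function_template = ('nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>'
                          '({input}, {output}, {w}, {wr}, {b}, {br});')

params = {
    'recr_type': 'lstm',        # node.class_name.lower()
    'input_t': 'input_t',
    'output_t': 'layer2_t',
    'config': 'config2',
    'input': 'layer1_out',
    'output': 'layer2_out',
    'w': 'w2', 'wr': 'wr2',     # kernel and recurrent kernel
    'b': 'b2', 'br': 'br2',     # bias and recurrent bias
}

print(recr_function_template.format(**params))
# nnet::lstm_stack<input_t, layer2_t, config2>(layer1_out, layer2_out, w2, wr2, b2, br2);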

9 changes: 6 additions & 3 deletions hls4ml/backends/vivado/passes/resource_strategy.py
@@ -1,13 +1,13 @@
import numpy as np

from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.layers import Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D
from hls4ml.model.layers import Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D, LSTM, GRU

class ApplyResourceStrategy(OptimizerPass):
''' Transposes the weights to use the dense_resource matrix multiply routine '''
def match(self, node):

node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D))
node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU))
is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource'
already_transformed = node.get_attr('_weights_transposed', False) == True

@@ -26,9 +26,12 @@ def transform(self, model, node):
elif isinstance(node, SeparableConv2D):
node.weights['depthwise'].data = np.transpose(node.weights['depthwise'].data, axes=[3, 0, 1, 2]) #(H,W,C,F) => (F,H,W,C)
node.weights['pointwise'].data = np.transpose(node.weights['pointwise'].data, axes=[3, 0, 1, 2]) #(H,W,C,F) => (F,H,W,C)
elif isinstance(node, (LSTM, GRU)):
node.weights['weight'].data = np.transpose(node.weights['weight'].data)
node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data)
else:
raise Exception('Unexpected layer {} with resource strategy'.format(node.class_name))

node.set_attr('_weights_transposed', True)

return False
return False
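
To illustrate the new LSTM/GRU branch of this pass (shapes are hypothetical, not from a real model): a Keras LSTM with 8 inputs and 16 units carries a kernel of shape (8, 64) and a recurrent kernel of shape (16, 64); the dense_resource routine expects the transposed layout, so the pass swaps both axes once and marks the node as transformed.

import numpy as np

# Sketch of the LSTM/GRU branch of ApplyResourceStrategy.transform
# (hypothetical shapes: 8 inputs, 16 units, 4 gates).
weight = np.zeros((8, 4 * 16))
recurrent_weight = np.zeros((16, 4 * 16))

weight = np.transpose(weight)                      # (8, 64)  -> (64, 8)
recurrent_weight = np.transpose(recurrent_weight)  # (16, 64) -> (64, 16)

print(weight.shape, recurrent_weight.shape)  # (64, 8) (64, 16)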
6 changes: 3 additions & 3 deletions hls4ml/backends/vivado/vivado_backend.py
@@ -23,9 +23,9 @@ def __init__(self):

def _register_layer_attributes(self):
extended_attrs = {
SimpleRNN: [Attribute('recurrent_reuse_factor', default=1)],
LSTM: [Attribute('recurrent_reuse_factor', default=1)],
GRU: [Attribute('recurrent_reuse_factor', default=1)],
SimpleRNN: [Attribute('recurrent_reuse_factor', default=1), Attribute('static', value_type=bool, default=True)],
LSTM: [Attribute('recurrent_reuse_factor', default=1), Attribute('static', value_type=bool, default=True)],
GRU: [Attribute('recurrent_reuse_factor', default=1), Attribute('static', value_type=bool, default=True)],
}
self.attribute_map.update(extended_attrs)

5 changes: 2 additions & 3 deletions hls4ml/model/layers.py
@@ -863,10 +863,9 @@ class SimpleRNN(Layer):
def initialize(self):
if self.attributes['return_sequences']:
shape = [self.attributes['n_timesteps'], self.attributes['n_out']]
dims = ['N_TIME_STEPS_{}'.format(self.index), 'N_OUT_{}'.format(self.index)]
else:
shape = [self.attributes['n_out']]
dims = ['N_OUT_{}'.format(self.index)]
shape = [1, self.attributes['n_out']]
dims = ['N_TIME_STEPS_{}'.format(self.index), 'N_OUT_{}'.format(self.index)]

self.add_output_variable(shape, dims)

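With this change the non-return_sequences case keeps a two-dimensional output (a single time step by n_out) instead of a one-dimensional one, so both cases expose the same N_TIME_STEPS/N_OUT dimension names. A minimal sketch with hypothetical attribute values:

# Sketch of SimpleRNN.initialize() after this change (attribute values are hypothetical).
attributes = {'n_timesteps': 10, 'n_out': 16, 'return_sequences': False}
index = 2

if attributes['return_sequences']:
    shape = [attributes['n_timesteps'], attributes['n_out']]
else:
    shape = [1, attributes['n_out']]
dims = ['N_TIME_STEPS_{}'.format(index), 'N_OUT_{}'.format(index)]

print(shape, dims)  # [1, 16] ['N_TIME_STEPS_2', 'N_OUT_2']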
5 changes: 4 additions & 1 deletion hls4ml/model/optimizer/passes/multi_dense.py
@@ -5,7 +5,10 @@
class ReplaceMultidimensionalDenseWithConv(OptimizerPass):
def match(self, node):
return isinstance(node, Dense) and \
len(node.get_input_variable().shape) > 1
len(node.get_input_variable().shape) - sum(d==1 for d in node.get_input_variable().shape) > 1
# The sum counts the input dimensions of size 1, so the subtraction leaves only
# the number of non-unit dimensions. For example, this prevents matching a Dense
# layer whose input has shape (1, N).

def transform(self, model, node):
dim = len(node.get_input_variable().shape) - 1
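The tightened condition only fires when the input has more than one non-unit dimension; a minimal sketch of just that shape check on a few hypothetical shapes (the real match() additionally requires the node to be a Dense layer):

# Sketch of the tightened shape check in ReplaceMultidimensionalDenseWithConv.match()
# (input shapes are hypothetical).
def has_multidim_input(shape):
    return len(shape) - sum(d == 1 for d in shape) > 1

print(has_multidim_input([10, 16]))  # True  -> replaced with a Conv by this pass
print(has_multidim_input([1, 16]))   # False -> left as a Dense layer
print(has_multidim_input([16]))      # False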
63 changes: 63 additions & 0 deletions hls4ml/templates/vivado/nnet_utils/nnet_recr_activations.h
@@ -0,0 +1,63 @@
#ifndef NNET_RECR_ACTIVATION_H_
#define NNET_RECR_ACTIVATION_H_

#include "nnet_common.h"
#include "nnet_helpers.h"
#include "nnet_activation.h"
#include "hls_stream.h"
#include <math.h>

namespace nnet {

namespace activation{

template<class data_T, class res_T, typename CONFIG_T>
class Activation{
public:
// *************************************************
// Blank Activation
// *************************************************
static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {} // Nothing to do here
};

template<class data_T, class res_T, typename CONFIG_T>
class relu : public Activation<data_T, res_T, CONFIG_T>{
public:
// *************************************************
// Relu Activation
// *************************************************
static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
{
nnet::relu<data_T, res_T, CONFIG_T>(data, res);
}
};

template<class data_T, class res_T, typename CONFIG_T>
class sigmoid : public Activation<data_T, res_T, CONFIG_T>{
public:
// *************************************************
// Sigmoid Activation
// *************************************************
static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
{
nnet::sigmoid<data_T, res_T, CONFIG_T>(data, res);
}
};

template<class data_T, class res_T, typename CONFIG_T>
class tanh : public Activation<data_T, res_T, CONFIG_T>{
public:
// *************************************************
// TanH Activation
// *************************************************
static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
{
nnet::tanh<data_T, res_T, CONFIG_T>(data, res);
}
};

}

}

#endif