RNN support (part 1) #521

Merged (5 commits, Apr 28, 2022)
45 changes: 34 additions & 11 deletions hls4ml/backends/fpga/fpga_backend.py
@@ -59,28 +59,48 @@ def write(self, model):
def get_writer_flow(self):
raise NotImplementedError

def get_valid_reuse_factors(self, layer):
n_in = 0
n_out = 0
def get_layer_mult_size(self, layer):
if 'Dense' in layer.class_name:
n_in = layer.get_attr('n_in')
n_out = layer.get_attr('n_out')
elif 'Conv1D' in layer.class_name:
return n_in, n_out

if 'Conv1D' in layer.class_name:
n_in = layer.get_attr('n_chan') * layer.get_attr('filt_width')
n_out = layer.get_attr('n_filt')
elif 'Conv2D' in layer.class_name:
return n_in, n_out

if 'Conv2D' in layer.class_name:
n_in = layer.get_attr('n_chan') * layer.get_attr('filt_height') * layer.get_attr('filt_width')
n_out = layer.get_attr('n_filt')
return n_in, n_out

if 'LSTM' in layer.class_name:
n_in = layer.get_attr('n_in')
n_out = layer.get_attr('n_out') * 4
n_in_recr = layer.get_attr('n_out')
n_out_recr = n_out
return n_in, n_out, n_in_recr, n_out_recr

if 'GRU' in layer.class_name:
n_in = layer.get_attr('n_in')
n_out = layer.get_attr('n_out') * 3
n_in_recr = layer.get_attr('n_out')
n_out_recr = n_out
return n_in, n_out, n_in_recr, n_out_recr

raise Exception(f'Cannot get mult size for layer {layer.name} ({layer.class_name})')

def get_valid_reuse_factors(self, n_in, n_out):
max_rf = n_in * n_out
valid_reuse_factors = []
for rf in range(1, max_rf + 1):
_assert = self._check_conditions(n_in, n_out, rf)
_assert = self._validate_reuse_factor(n_in, n_out, rf)
if _assert:
valid_reuse_factors.append(rf)
return valid_reuse_factors

def _check_conditions(self, n_in, n_out, rf):
def _validate_reuse_factor(self, n_in, n_out, rf):
multfactor = min(n_in, rf)
multiplier_limit = int(math.ceil((n_in * n_out) / float(multfactor)))
#
@@ -112,16 +132,19 @@ def get_closest_reuse_factor(self, valid_rf, chosen_rf):
else:
return before

def set_closest_reuse_factor(self, layer):
valid_rf = self.get_valid_reuse_factors(layer)
chosen_rf = layer.get_attr('reuse_factor')
def set_closest_reuse_factor(self, layer, n_in, n_out, attribute='reuse_factor'):
assert attribute is not None, 'Reuse factor attribute cannot be None'

valid_rf = self.get_valid_reuse_factors(n_in, n_out)
chosen_rf = layer.get_attr(attribute)
if chosen_rf not in valid_rf:
closest_rf = self.get_closest_reuse_factor(valid_rf, chosen_rf)
print('WARNING: Invalid ReuseFactor={} in layer "{}". Using ReuseFactor={} instead. Valid ReuseFactor(s): {}.'
.format(chosen_rf, layer.name, closest_rf, ','.join(map(str, valid_rf))))
layer.set_attr('reuse_factor', closest_rf)
layer.set_attr(attribute, closest_rf)

def set_target_reuse_factor(self, layer):
# TODO update target reuse factor for the RNN layers
targ_cycles = layer.get_attr('target_cycles')

shuffle_cycles = 6 # Number of clock cycles to move data around
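For orientation, the snippet below is a standalone sketch (not hls4ml code) of the two ideas in this file: how get_layer_mult_size maps an LSTM's kernel and recurrent kernel to multiplier dimensions, and how a list of valid reuse factors is enumerated from them. The divisibility check is a simplified stand-in, not the exact conditions applied by _validate_reuse_factor.

# Standalone sketch: multiplier sizes and the reuse-factor search.

def lstm_mult_size(n_in, units):
    """Mirror the LSTM branch of get_layer_mult_size: the kernel multiplies
    n_in inputs into 4*units gate outputs; the recurrent kernel multiplies
    units states into the same 4*units outputs."""
    n_out = 4 * units
    return n_in, n_out, units, n_out

def valid_reuse_factors(n_in, n_out, is_valid):
    """Enumerate reuse factors from 1 to n_in*n_out that pass a validity
    check, as get_valid_reuse_factors does with _validate_reuse_factor."""
    return [rf for rf in range(1, n_in * n_out + 1) if is_valid(n_in, n_out, rf)]

def simple_check(n_in, n_out, rf):
    # Illustrative assumption only: accept rf values that evenly divide the
    # total multiplier count. The real check enforces further constraints.
    return (n_in * n_out) % rf == 0

if __name__ == '__main__':
    n_in, n_out, n_in_recr, n_out_recr = lstm_mult_size(n_in=10, units=20)
    print(n_in, n_out, n_in_recr, n_out_recr)       # 10 80 20 80
    print(valid_reuse_factors(4, 8, simple_check))  # [1, 2, 4, 8, 16, 32]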
3 changes: 2 additions & 1 deletion hls4ml/backends/quartus/quartus_backend.py
@@ -156,7 +156,8 @@ def init_dense(self, layer):
if layer.model.config.get_compression(layer):
layer.set_attr('strategy', 'compressed')
else:
self.set_closest_reuse_factor(layer)
n_in, n_out = self.get_layer_mult_size(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
self.gen_quartus_weight_array(layer)
layer.set_attr('strategy', 'resource')

38 changes: 37 additions & 1 deletion hls4ml/backends/vivado/passes/core_templates.py
@@ -1,6 +1,6 @@

from hls4ml.backends.backend import get_backend
from hls4ml.model.layers import Activation, BatchNormalization, Dense, PReLU, ParametrizedActivation, Softmax
from hls4ml.model.layers import Activation, BatchNormalization, Dense, Embedding, PReLU, ParametrizedActivation, Softmax
from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate

# Dense templates
@@ -174,3 +174,39 @@ def format(self, node):
params['config'] = '{}_config{}'.format(node.get_attr('activation'), node.index)

return self.template.format(**params)


# Embedding templates

embed_config_template = """struct config{index} : nnet::embed_config {{
static const unsigned n_in = {n_in};
static const unsigned n_out = {n_out};
static const unsigned vocab_size = {vocab_size};
static const unsigned io_type = nnet::{iotype};
static const unsigned reuse_factor = {reuse};
typedef {embeddings_t.name} embeddings_t;
}};\n"""

embed_function_template = 'nnet::embedding<{input_t}, {output_t}, {config}>({input}, {output}, {e});'

embed_include_list = ['nnet_utils/nnet_embed.h', 'nnet_utils/nnet_embed_stream.h']

class EmbeddingConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(Embedding)
self.template = embed_config_template

def format(self, node):
params = self._default_config_params(node)
return self.template.format(**params)

class EmbeddingFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(Embedding, include_header=embed_include_list)
self.template = embed_function_template

def format(self, node):
params = self._default_function_params(node)
params['e'] = node.get_weights('embeddings').name

return self.template.format(**params)
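To see what one of these templates expands to, the snippet below fills embed_config_template by hand with hypothetical parameter values (in the real flow _default_config_params(node) collects them from the layer); SimpleNamespace stands in for the weight-type object whose .name the {embeddings_t.name} placeholder reads.

from types import SimpleNamespace

embed_config_template = """struct config{index} : nnet::embed_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned vocab_size = {vocab_size};
    static const unsigned io_type = nnet::{iotype};
    static const unsigned reuse_factor = {reuse};
    typedef {embeddings_t.name} embeddings_t;
}};\n"""

# Hypothetical values for an Embedding(1000, 16) layer on length-100 sequences.
params = {
    'index': 2,
    'n_in': 100,
    'n_out': 16,
    'vocab_size': 1000,
    'iotype': 'io_parallel',
    'reuse': 1,
    'embeddings_t': SimpleNamespace(name='embedding2_embeddings_t'),
}

print(embed_config_template.format(**params))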
87 changes: 80 additions & 7 deletions hls4ml/backends/vivado/vivado_backend.py
@@ -7,7 +7,7 @@
from collections.abc import Iterable

from hls4ml.model.types import FixedPrecisionType, NamedType, IntegerPrecisionType
from hls4ml.model.layers import Layer, Dense, BatchNormalization, Conv1D, Conv2D, Conv2DBatchnorm, SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Activation, ParametrizedActivation, PReLU, Softmax, Pooling1D, Pooling2D, GlobalPooling1D, GlobalPooling2D, ZeroPadding1D, ZeroPadding2D, Merge, Concatenate, Dot, Resize, Transpose, GarNet, GarNetStack
from hls4ml.model.layers import Layer, Dense, BatchNormalization, Embedding, Conv1D, Conv2D, Conv2DBatchnorm, SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Activation, ParametrizedActivation, PReLU, Softmax, Pooling1D, Pooling2D, GlobalPooling1D, GlobalPooling2D, ZeroPadding1D, ZeroPadding2D, Merge, Concatenate, Dot, Resize, Transpose, SimpleRNN, LSTM, GRU, GarNet, GarNetStack
from hls4ml.model.attributes import Attribute
from hls4ml.model.optimizer import get_backend_passes, layer_optimizer, model_optimizer
from hls4ml.model.flow import register_flow
@@ -18,8 +18,17 @@
class VivadoBackend(FPGABackend):
def __init__(self):
super(VivadoBackend, self).__init__('Vivado')
self._register_layer_attributes()
self._register_flows()

def _register_layer_attributes(self):
extended_attrs = {
SimpleRNN: [Attribute('recurrent_reuse_factor', default=1)],
LSTM: [Attribute('recurrent_reuse_factor', default=1)],
GRU: [Attribute('recurrent_reuse_factor', default=1)],
}
self.attribute_map.update(extended_attrs)

def _register_flows(self):
initializers = self._get_layer_initializers()
init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)
@@ -123,8 +132,9 @@ def init_dense(self, layer):
index_t = IntegerPrecisionType(width=1, signed=False)
compression = layer.model.config.get_compression(layer)
if layer.model.config.is_resource_strategy(layer):
n_in, n_out = self.get_layer_mult_size(layer)
self.set_target_reuse_factor(layer)
self.set_closest_reuse_factor(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
if compression:
layer.set_attr('strategy', 'compressed')
index_t = layer.get_weights('weight').type.index_precision
@@ -142,8 +152,9 @@

if layer.model.config.is_resource_strategy(layer):
layer.set_attr('strategy', 'resource')
n_in, n_out = self.get_layer_mult_size(layer)
self.set_target_reuse_factor(layer)
self.set_closest_reuse_factor(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
else:
layer.set_attr('strategy', 'latency')

@@ -153,7 +164,8 @@
def init_sepconv1d(self, layer):
if layer.model.config.is_resource_strategy(layer):
layer.set_attr('strategy', 'resource')
self.set_closest_reuse_factor(layer)
n_in, n_out = self.get_layer_mult_size(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
else:
layer.set_attr('strategy', 'latency')

@@ -167,7 +179,8 @@
if layer.model.config.is_resource_strategy(layer):
layer.set_attr('strategy', 'resource')
self.set_target_reuse_factor(layer)
self.set_closest_reuse_factor(layer)
n_in, n_out = self.get_layer_mult_size(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
else:
layer.set_attr('strategy', 'latency')

@@ -177,7 +190,8 @@
def init_sepconv2d(self, layer):
if layer.model.config.is_resource_strategy(layer):
layer.set_attr('strategy', 'resource')
self.set_closest_reuse_factor(layer)
n_in, n_out = self.get_layer_mult_size(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
else:
layer.set_attr('strategy', 'latency')

@@ -187,7 +201,8 @@
def init_depconv2d(self, layer):
if layer.model.config.is_resource_strategy(layer):
layer.set_attr('strategy', 'resource')
self.set_closest_reuse_factor(layer)
n_in, n_out = self.get_layer_mult_size(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
else:
layer.set_attr('strategy', 'latency')

@@ -215,6 +230,64 @@ def init_softmax(self, layer):
if layer.model.config.get_config_value('IOType') == 'io_parallel':
assert len(layer.get_input_variable().shape) == 1, 'Softmax with io_parallel strategy cannot be used on multidimensional tensors.'

@layer_optimizer(Embedding)
def init_embed(self, layer):
if layer.attributes['n_in'] is None:
raise Exception('Input length of Embedding layer must be specified.')

@layer_optimizer(LSTM)
def init_lstm(self, layer):
# TODO Allow getting recurrent reuse factor from the config
reuse_factor = layer.model.config.get_reuse_factor(layer)
layer.set_attr('recurrent_reuse_factor', reuse_factor)

recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[1])
layer.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)

index_t = IntegerPrecisionType(width=1, signed=False)

if 'table_t' not in layer.attributes:
layer.set_attr('table_t', FixedPrecisionType(width=18, integer=8))
if 'table_size' not in layer.attributes:
layer.set_attr('table_size', 1024)
if layer.model.config.is_resource_strategy(layer):
n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor')
layer.weights['weight'].data = np.transpose(layer.weights['weight'].data)
layer.weights['recurrent_weight'].data = np.transpose(layer.weights['recurrent_weight'].data)
layer.set_attr('strategy', 'resource')
else:
layer.set_attr('strategy', 'latency')

layer.set_attr('index_t', index_t)

@layer_optimizer(GRU)
def init_gru(self, layer):
reuse_factor = layer.model.config.get_reuse_factor(layer)
layer.set_attr('recurrent_reuse_factor', reuse_factor)

recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[1])
layer.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)

index_t = IntegerPrecisionType(width=1, signed=False)

if 'table_t' not in layer.attributes:
layer.set_attr('table_t', FixedPrecisionType(width=18, integer=8))
if 'table_size' not in layer.attributes:
layer.set_attr('table_size', 1024)
if layer.model.config.is_resource_strategy(layer):
n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer)
self.set_closest_reuse_factor(layer, n_in, n_out)
self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor')
layer.weights['weight'].data = np.transpose(layer.weights['weight'].data)
layer.weights['recurrent_weight'].data = np.transpose(layer.weights['recurrent_weight'].data)
layer.set_attr('strategy', 'resource')
else:
layer.set_attr('strategy', 'latency')

layer.set_attr('index_t', index_t)

@layer_optimizer(GarNet)
def init_garnet(self, layer):
reuse_factor = layer.attributes['reuse_factor']
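As a concrete check of the shapes init_lstm manipulates, the sketch below builds dummy weights for a hypothetical LSTM with 8 input features and 16 units (Keras weight layout assumed) and applies the same recurrent-bias sizing and resource-strategy transposition as above.

import numpy as np

n_in, units = 8, 16                              # hypothetical LSTM dimensions
weight = np.zeros((n_in, 4 * units))             # kernel, shape (8, 64)
recurrent_weight = np.zeros((units, 4 * units))  # recurrent kernel, shape (16, 64)

# Same rule as init_lstm: the recurrent bias spans all four gate outputs.
recurrent_bias = np.zeros(recurrent_weight.shape[1])
print(recurrent_bias.shape)                      # (64,)

# Resource strategy transposes both matrices before code generation.
print(np.transpose(weight).shape)                # (64, 8)
print(np.transpose(recurrent_weight).shape)      # (64, 16)

# Multiplier sizes passed to the two set_closest_reuse_factor calls:
n_out = 4 * units
n_in_recr, n_out_recr = units, n_out
print((n_in, n_out), (n_in_recr, n_out_recr))    # (8, 64) (16, 64)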
15 changes: 15 additions & 0 deletions hls4ml/converters/keras/core.py
@@ -127,3 +127,18 @@ def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader, config):
layer['n_filt']=input_shapes[0][3]

return layer, [shape for shape in input_shapes[0]]


@keras_handler('Embedding')
def parse_embedding_layer(keras_layer, input_names, input_shapes, data_reader, config):
assert('Embedding' in keras_layer['class_name'])

layer = parse_default_keras_layer(keras_layer, input_names)

layer['n_in'] = input_shapes[0][1]
layer['vocab_size'] = keras_layer['config']['input_dim']
layer['n_out'] = keras_layer['config']['output_dim']

output_shape = input_shapes[0] + [layer['n_out']]

return layer, output_shape
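The shape handling above can be traced by hand; the sketch below reproduces just the three assignments and the output-shape rule for a hypothetical Embedding(input_dim=1000, output_dim=16) applied to batches of length-100 token sequences (it does not call the hls4ml handler itself).

keras_layer = {'class_name': 'Embedding',
               'config': {'input_dim': 1000, 'output_dim': 16}}
input_shapes = [[None, 100]]        # (batch, sequence length)

layer = {}
layer['n_in'] = input_shapes[0][1]                        # 100 tokens per sample
layer['vocab_size'] = keras_layer['config']['input_dim']  # 1000
layer['n_out'] = keras_layer['config']['output_dim']      # 16

output_shape = input_shapes[0] + [layer['n_out']]
print(output_shape)                 # [None, 100, 16]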
44 changes: 44 additions & 0 deletions hls4ml/converters/keras/recurrent.py
@@ -0,0 +1,44 @@
import numpy as np

from hls4ml.converters.keras_to_hls import parse_default_keras_layer
from hls4ml.converters.keras_to_hls import keras_handler

from hls4ml.model.types import Quantizer
from hls4ml.model.types import IntegerPrecisionType

rnn_layers = ['SimpleRNN', 'LSTM', 'GRU']
@keras_handler(*rnn_layers)
def parse_rnn_layer(keras_layer, input_names, input_shapes, data_reader, config):
assert(keras_layer['class_name'] in rnn_layers)

layer = parse_default_keras_layer(keras_layer, input_names)

layer['return_sequences'] = keras_layer['config']['return_sequences']
layer['return_state'] = keras_layer['config']['return_state']

if layer['class_name'] != 'SimpleRNN':
layer['recurrent_activation'] = keras_layer['config']['recurrent_activation']

layer['time_major'] = keras_layer['config']['time_major'] if 'time_major' in keras_layer['config'] else False

# TODO Should we handle time_major?
if layer['time_major']:
raise Exception('Time-major format of {} layer is not supported by hls4ml.'.format(layer['class_name']))

layer['n_timesteps'] = input_shapes[0][1]
layer['n_in'] = input_shapes[0][2]

layer['n_out'] = keras_layer['config']['units']

if layer['class_name'] == 'GRU':
layer['apply_reset_gate'] = 'after' if keras_layer['config']['reset_after'] else 'before'

if layer['return_sequences']:
output_shape = [input_shapes[0][0], layer['n_timesteps'], layer['n_out']]
else:
output_shape = [input_shapes[0][0], layer['n_out']]

if layer['return_state']:
raise Exception('"return_state" of {} layer is not yet supported.'.format(layer['class_name']))

return layer, output_shape
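Similarly, the dimension and output-shape logic of parse_rnn_layer can be followed with a hypothetical LSTM(32) applied to input of shape (batch, 10, 4); the dictionary below mimics only the fields the handler reads and is not a complete Keras config.

keras_layer = {
    'class_name': 'LSTM',
    'config': {'units': 32, 'return_sequences': False, 'return_state': False,
               'recurrent_activation': 'sigmoid', 'time_major': False},
}
input_shapes = [[None, 10, 4]]      # (batch, timesteps, features)

layer = {'class_name': keras_layer['class_name']}
layer['n_timesteps'] = input_shapes[0][1]          # 10
layer['n_in'] = input_shapes[0][2]                 # 4
layer['n_out'] = keras_layer['config']['units']    # 32

if keras_layer['config']['return_sequences']:
    output_shape = [input_shapes[0][0], layer['n_timesteps'], layer['n_out']]
else:
    output_shape = [input_shapes[0][0], layer['n_out']]

print(output_shape)                 # [None, 32]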
6 changes: 5 additions & 1 deletion hls4ml/converters/keras_to_hls.py
@@ -229,6 +229,10 @@ def keras_to_hls(config):

#Define layers to skip for conversion to HLS
skip_layers = ['Dropout']
# Activation layers
activation_layers = ['Activation', 'LeakyReLU', 'ThresholdedReLU', 'ELU', 'PReLU', 'Softmax', 'TernaryTanh']
# Recurrent layers
recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU']
#All supported layers
supported_layers = get_supported_keras_layers() + skip_layers

@@ -310,7 +314,7 @@ def keras_to_hls(config):

print('Layer name: {}, layer type: {}, input shapes: {}, output shape: {}'.format(layer['name'], layer['class_name'], input_shapes, output_shape))
layer_list.append( layer )
if 'activation' in layer and layer['class_name'] not in ['Activation', 'LeakyReLU', 'ThresholdedReLU', 'ELU', 'PReLU', 'Softmax', 'TernaryTanh']:# + qkeras_layers:
if 'activation' in layer and layer['class_name'] not in activation_layers + recurrent_layers:# + qkeras_layers:
act_layer = {}
act_layer['name'] = layer['name'] + '_' + layer['activation']
act_layer['activation'] = layer['activation']
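The practical effect of adding recurrent_layers to the exclusion list is that the converter no longer splits out a separate Activation node for SimpleRNN/LSTM/GRU layers, whose activations are applied inside the recurrent kernel. A minimal sketch of the decision, with hypothetical layer dictionaries:

activation_layers = ['Activation', 'LeakyReLU', 'ThresholdedReLU', 'ELU',
                     'PReLU', 'Softmax', 'TernaryTanh']
recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU']

def needs_separate_activation(layer):
    """Mirror of the check in keras_to_hls: emit an extra Activation layer
    unless the layer is itself an activation or a recurrent layer."""
    return ('activation' in layer
            and layer['class_name'] not in activation_layers + recurrent_layers)

print(needs_separate_activation({'class_name': 'Dense', 'activation': 'relu'}))  # True
print(needs_separate_activation({'class_name': 'LSTM', 'activation': 'tanh'}))   # False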