diff --git a/hls4ml/backends/backend.py b/hls4ml/backends/backend.py
index 05385617e9..b121044629 100644
--- a/hls4ml/backends/backend.py
+++ b/hls4ml/backends/backend.py
@@ -26,6 +26,11 @@ def _init_file_optimizers(self):
         opt_path = os.path.dirname(inspect.getfile(self.__class__)) + '/passes'
         module_path = self.__module__[:self.__module__.rfind('.')] + '.passes'
         file_optimizers = extract_optimizers_from_path(opt_path, module_path, self)
+        for base in self.__class__.__bases__:
+            opt_path = os.path.dirname(inspect.getfile(base)) + '/passes'
+            module_path = base.__module__[:base.__module__.rfind('.')] + '.passes'
+            base_optimizers = extract_optimizers_from_path(opt_path, module_path, self)
+            file_optimizers.update(base_optimizers)
         return file_optimizers
 
     def _get_layer_initializers(self):
diff --git a/hls4ml/backends/fpga/passes/__init__.py b/hls4ml/backends/fpga/passes/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/hls4ml/backends/quartus/passes/bn_quant.py b/hls4ml/backends/fpga/passes/bn_quant.py
similarity index 98%
rename from hls4ml/backends/quartus/passes/bn_quant.py
rename to hls4ml/backends/fpga/passes/bn_quant.py
index 91b242fd23..b51d7610f1 100644
--- a/hls4ml/backends/quartus/passes/bn_quant.py
+++ b/hls4ml/backends/fpga/passes/bn_quant.py
@@ -15,7 +15,7 @@
 
 batchnorm_quantized_tanh_function_template = 'nnet::normalize_{quantize}_tanh<{input_t}, {config}>({input}, {output}, {threshold});'
 
-bn_include_list = ['nnet_utils/nnet_batchnorm.h']
+bn_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h']
 
 class BatchNormalizationQuantizedTanhConfigTemplate(LayerConfigTemplate):
     def __init__(self):
diff --git a/hls4ml/backends/vivado/passes/bn_quant.py b/hls4ml/backends/vivado/passes/bn_quant.py
deleted file mode 100644
index aebd4dee8f..0000000000
--- a/hls4ml/backends/vivado/passes/bn_quant.py
+++ /dev/null
@@ -1,157 +0,0 @@
-import numpy as np
-import re
-
-from hls4ml.model.optimizer import OptimizerPass
-from hls4ml.model.types import IntegerPrecisionType, NamedType, XnorPrecisionType
-from hls4ml.model.layers import Layer, Activation, Dense, BatchNormalization, register_layer
-from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
-from hls4ml.backends.fpga.fpga_layers import BatchNormalizationQuantizedTanh
-
-batchnorm_quantized_tanh_config_template = """struct config{index} : nnet::batchnorm_quantized_tanh_config {{
-    static const unsigned n_in = {n_in};
-    static const unsigned n_filt = {n_filt};
-    static const unsigned io_type = nnet::{iotype};
-    static const unsigned reuse_factor = {reuse};
-}};\n"""
-
-batchnorm_quantized_tanh_function_template = 'nnet::normalize_{quantize}_tanh<{input_t}, {config}>({input}, {output}, {threshold});'
-bn_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h']
-
-class BatchNormalizationQuantizedTanhConfigTemplate(LayerConfigTemplate):
-    def __init__(self):
-        super().__init__(BatchNormalizationQuantizedTanh)
-        self.template = batchnorm_quantized_tanh_config_template
-
-    def format(self, node):
-        params = self._default_config_params(node)
-        params['n_in'] = node.get_input_variable().size_cpp()
-
-        return self.template.format(**params)
-
-class BatchNormalizationQuantizedTanhFunctionTemplate(FunctionCallTemplate):
-    def __init__(self):
-        super().__init__(BatchNormalizationQuantizedTanh, include_header=bn_include_list)
-        self.template = batchnorm_quantized_tanh_function_template
-
-    def format(self, node):
-        params = self._default_function_params(node)
-        if node.get_attr('quantize') == 2:
-            params['quantize'] = 'binary'
-            params['threshold'] = node.get_weights('threshold').name
-        elif node.get_attr('quantize') == 3:
-            params['quantize'] = 'ternary'
-            params['threshold'] = node.get_weights('threshold_hi').name + ', ' + node.get_weights('threshold_lo').name
-
-        return self.template.format(**params)
-
-def register_bn_quant(backend):
-    # Register the layer types to the layer map
-    register_layer('BatchNormalizationQuantizedTanh', BatchNormalizationQuantizedTanh)
-
-    # Register the optimization passes
-    backend.register_pass('merge_batch_norm_quantized_tanh', MergeBatchNormAndQuantizedTanh)
-    backend.register_pass('quantize_dense_output', QuantizeDenseOutput)
-
-    # Register template passes
-    backend.register_template(BatchNormalizationQuantizedTanhConfigTemplate)
-    backend.register_template(BatchNormalizationQuantizedTanhFunctionTemplate)
-
-
-class MergeBatchNormAndQuantizedTanh(OptimizerPass):
-    def match(self, node):
-        is_match = (node.class_name == 'Activation'
-                    and node.get_attr('activation') in ['binary', 'binary_tanh', 'ternary', 'ternary_tanh']
-                    or node.class_name == 'TernaryTanh')
-        is_match = is_match and isinstance(node.get_input_node(), BatchNormalization)
-        return is_match
-
-    def transform(self, model, node):
-        bn_layer = node.get_input_node()
-        # Make a new layer with the new attributes
-        quantize = 0
-        if 'binary' in node.get_attr('activation'):
-            quantize = 2
-        if 'ternary' in node.get_attr('activation'):
-            quantize = 3
-        attrs = {
-            'name' : bn_layer.get_attr('name'),
-            'original_name' : bn_layer.get_attr('name'),
-            'class_name' : 'BatchNormalizationQuantizedTanh',
-            'n_in' : bn_layer.get_attr('n_in'),
-            'n_out' : bn_layer.get_attr('n_in'),
-            'n_filt' : bn_layer.get_attr('n_filt'),
-            'quantize' : quantize,
-            'Trace' : bn_layer.get_attr('Trace')
-        }
-        bnbt_layer = model.make_node(BatchNormalizationQuantizedTanh, 'bnbt_' + bn_layer.name, attrs, bn_layer.inputs)
-        bnbt_layer.set_thresholds(bn_layer.get_weights('scale').data, bn_layer.get_weights('bias').data, node.get_attr('threshold',0.5))
-        # Remove the BatchNormalization layer
-        model.remove_node(bn_layer, rewire=True)
-        # Replace the old Activation layer with this one
-        model.replace_node(node, bnbt_layer)
-
-        return True
-
-class QuantizeDenseOutput(OptimizerPass):
-    def match(self, node):
-        is_dense = node.class_name == 'Dense'
-        input_node = node.get_input_node()
-        is_input_bnqt = input_node is not None and input_node.class_name == 'BatchNormalizationQuantizedTanh'
-        quantizer = node.get_attr('weight_quantizer')
-        is_binary_ternary = quantizer is not None and (quantizer.__class__.__name__ == 'BinaryQuantizer' or quantizer.__class__.__name__ == 'TernaryQuantizer')
-        return is_dense and is_input_bnqt and is_binary_ternary
-
-    def transform(self, model, node):
-        # Compute the required precision and update the variables
-        # Number of bits for output is log2 of number of input nodes
-        # Since this is the number of uint<1>'s which are summed
-        nbits = int(np.ceil(np.log2(node.attributes['n_in'])) + 2)
-        out_type = IntegerPrecisionType(width=nbits)
-        accum_t = NamedType('layer{}_accum_t'.format(node.index), out_type)
-        node.set_attr('accum_t', accum_t)
-        out_var = node.get_output_variable()
-        out_var.type.precision = out_type
-
-        quantized_data = None
-        quantized_precision = None
-        quantizer = node.get_attr('weight_quantizer')
-        if quantizer.__class__.__name__ == 'BinaryQuantizer':
-            quantized_precision = XnorPrecisionType()
-        elif quantizer.__class__.__name__ == 'TernaryQuantizer':
-            quantized_precision = IntegerPrecisionType(width=2)
-        else:
-            print('WARNING: Unknown quantizer - {}. Bailing out'.format(quantizer.__class__.__name__))
-            return False
-        quantizer.bits = quantized_precision.width
-        quantizer.hls_type = quantized_precision
-        quantized_data = quantizer(node.weights['weight'].data)
-
-        weights = node.weights['weight']
-        weights.data = quantized_data
-        weights.type.name = 'weight{index}_t'.format(index=node.index)
-        weights.update_precision(quantized_precision)
-
-        bias = node.weights['bias']
-        bias.data = np.zeros(shape=(node.get_attr('n_out')))
-        bias.type.name = 'bias{index}_t'.format(index=node.index)
-        bias.nzeros = 0
-        bias.update_precision(quantized_precision)
-
-        # If followed by the BatchNormalizationBinaryTanh, update its input
-        # Also requantise the weights
-        bd_out_nodes = node.get_output_nodes()
-        for out_node in bd_out_nodes:
-            if isinstance(out_node, BatchNormalizationQuantizedTanh):
-                var_names = []
-                if quantizer.__class__.__name__ == 'BinaryQuantizer':
-                    var_names.append('threshold')
-                elif quantizer.__class__.__name__ == 'TernaryQuantizer':
-                    var_names.append('threshold_hi')
-                    var_names.append('threshold_lo')
-                for var_name in var_names:
-                    threshold_var = out_node.weights[var_name]
-                    threshold_var.update_precision(out_type)
-                    threshold_var.data = np.floor(threshold_var.data)
-
-        return False
-
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm_stream.h
new file mode 100644
index 0000000000..a87d813151
--- /dev/null
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm_stream.h
@@ -0,0 +1,33 @@
+//
+// rfnoc-hls-neuralnet: Vivado HLS code for neural-net building blocks
+//
+// Copyright (C) 2017 EJ Kreinar
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+//
+
+/*
+* PLACEHOLDER - The common pass bn_quant.py includes both parallel and streaming BN; streaming is currently not supported in Quartus
+*/
+
+#ifndef NNET_BATCHNORM_STREAM_H_
+#define NNET_BATCHNORM_STREAM_H_
+
+#include "nnet_common.h"
+#include "nnet_helpers.h"
+#include "nnet_mult.h"
+
+namespace nnet {}
+
+#endif
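
A minimal standalone sketch of the lookup introduced in the backend.py hunk, using placeholder FPGABackend/QuartusBackend classes (the real classes live under hls4ml.backends), is shown below: the passes directory is resolved for the backend class itself and for each of its direct base classes, which is what lets the relocated hls4ml/backends/fpga/passes/bn_quant.py be discovered by every backend that inherits directly from the common FPGA backend.

```python
# Minimal, self-contained sketch (not hls4ml code; class names are placeholders)
# of the base-class pass discovery performed by _init_file_optimizers above.
import inspect
import os


class FPGABackend:                    # stands in for the shared FPGA backend
    pass


class QuartusBackend(FPGABackend):    # stands in for a concrete backend
    pass


def candidate_pass_dirs(backend_cls):
    """Return '<module dir>/passes' for backend_cls and its direct bases."""
    dirs = []
    for cls in (backend_cls, *backend_cls.__bases__):
        try:
            cls_dir = os.path.dirname(inspect.getfile(cls))
        except (TypeError, OSError):
            continue  # built-in bases such as `object` have no source file
        dirs.append(os.path.join(cls_dir, 'passes'))
    return dirs


if __name__ == '__main__':
    # In hls4ml the two classes live in different subpackages, so this yields
    # both backends/quartus/passes and backends/fpga/passes; here both entries
    # point to this script's directory because the placeholder classes share a file.
    print(candidate_pass_dirs(QuartusBackend))
```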