Improved parsing of pytorch models using torch.FX - Clean #799

Merged · 11 commits · Jun 22, 2023
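This PR moves the PyTorch frontend from walking the model's modules to tracing it with torch.FX, so functional calls appear in the parsed graph alongside module calls. A minimal sketch of what FX tracing exposes, with an invented model (not part of the PR):

```python
import torch.nn as nn
from torch.fx import symbolic_trace

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.fc(x))

traced = symbolic_trace(TinyNet())
for node in traced.graph.nodes:
    # node.op is e.g. 'placeholder', 'call_module', 'call_function' or 'output';
    # for 'call_module' nodes, node.target names the underlying nn.Module.
    print(node.op, node.target)
```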
2 changes: 2 additions & 0 deletions hls4ml/backends/quartus/passes/pooling_templates.py
@@ -18,6 +18,7 @@

static const unsigned pad_left = {pad_left};
static const unsigned pad_right = {pad_right};
static const bool count_pad = {count_pad};

static const nnet::Pool_Op pool_op = nnet::{pool_op};
typedef {accum_t.name} accum_t;
@@ -44,6 +45,7 @@
static const unsigned pad_bottom = {pad_bottom};
static const unsigned pad_left = {pad_left};
static const unsigned pad_right = {pad_right};
static const bool count_pad = {count_pad};

static const nnet::Pool_Op pool_op = nnet::{pool_op};
typedef {accum_t.name} accum_t;
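The new `count_pad` template parameter presumably carries PyTorch's `count_include_pad` option for average pooling, i.e. whether padded zeros enter the averaging denominator. A small PyTorch illustration of the two behaviours (my example, not part of the diff):

```python
import torch
import torch.nn as nn

x = torch.ones(1, 1, 2, 2)  # all-ones input
pool_incl = nn.AvgPool2d(2, stride=1, padding=1, count_include_pad=True)
pool_excl = nn.AvgPool2d(2, stride=1, padding=1, count_include_pad=False)
print(pool_incl(x))  # corner outputs are 0.25: one real value over 4 counted positions
print(pool_excl(x))  # corner outputs are 1.0: padded zeros excluded from the count
```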
14 changes: 10 additions & 4 deletions hls4ml/backends/template.py
@@ -26,6 +26,14 @@ def format(self, node):
def get_name(self):
return self.name

def _default_params(self, node):
params = {}
params.update(node.attributes)
# Convert all bool attributes to lowercase strings
params = {key: str(val).lower() if type(val) == bool else val for key, val in params.items()}

return params


class LayerConfigTemplate(Template):
def __init__(self, layer_class):
@@ -37,8 +45,7 @@ def __init__(self, layer_class):
super().__init__(name, layer_class, 'config_cpp')

def _default_config_params(self, layer):
params = {}
params.update(layer.attributes)
params = self._default_params(layer)
params['iotype'] = layer.model.config.get_config_value('IOType')
params['reuse'] = layer.get_attr('reuse_factor')

@@ -59,8 +66,7 @@ def __init__(self, layer_class, include_header=None):
self.include_header = include_header

def _default_function_params(self, layer):
params = {}
params.update(layer.attributes)
params = self._default_params(layer)
params['config'] = f'config{layer.index}'
params['input_t'] = layer.get_input_variable().type.name
params['output_t'] = layer.get_output_variable().type.name
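The shared `_default_params` helper exists because the rendered output is C++: Python's `True`/`False` must appear as `true`/`false` in config structs like the new `count_pad` lines above. A standalone sketch of the same conversion:

```python
# Standalone sketch of the bool-to-C++-literal conversion done by _default_params.
attributes = {'count_pad': True, 'n_filt': 8, 'pool_op': 'Average'}
params = {key: str(val).lower() if type(val) == bool else val for key, val in attributes.items()}

template = 'static const bool count_pad = {count_pad};'
print(template.format(**params))  # -> static const bool count_pad = true;
```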
2 changes: 2 additions & 0 deletions hls4ml/backends/vivado/passes/pooling_templates.py
@@ -14,6 +14,7 @@

static const unsigned pad_left = {pad_left};
static const unsigned pad_right = {pad_right};
static const bool count_pad = {count_pad};
static const unsigned stride_width = {stride_width};
static const nnet::Pool_Op pool_op = nnet::{pool_op};
static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
@@ -40,6 +41,7 @@
static const unsigned pad_bottom = {pad_bottom};
static const unsigned pad_left = {pad_left};
static const unsigned pad_right = {pad_right};
static const bool count_pad = {count_pad};
static const nnet::Pool_Op pool_op = nnet::{pool_op};
static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
static const unsigned reuse_factor = {reuse};
14 changes: 12 additions & 2 deletions hls4ml/converters/__init__.py
@@ -246,8 +246,8 @@ def convert_from_pytorch_model(
"""Convert PyTorch model to hls4ml model based on the provided configuration.

Args:
model: PyTorch model to conert.
input_shape (list): The shape of the input tensor.
model: PyTorch model to convert.
input_shape (list): The shape of the input tensor. The first element is the batch size and must be None.
output_dir (str, optional): Output directory of the generated HLS project. Defaults to 'my-hls-test'.
project_name (str, optional): Name of the HLS project. Defaults to 'myproject'.
input_data_tb (str, optional): String representing the path of input data in .npy or .dat format that will be
@@ -270,6 +270,16 @@ def convert_from_pytorch_model(
Raises:
Exception: If precision and reuse factor are not present in 'hls_config'.

Notes:
PyTorch uses the "channels_first" data format for its tensors, while hls4ml expects the "channels_last" format
used by Keras. By default, hls4ml will automatically add layers to the model that transpose the inputs to the
"channels_last" format. Note that this is not supported for the "io_stream" io_type, for which the user has to
transpose the input by hand before passing it to hls4ml. In that case, the "inputs_channel_last" argument of
the "config_from_pytorch_model" function needs to be set to True. By default, the output of the model remains
in the "channels_last" data format. The "transpose_outputs" argument of "config_from_pytorch_model" can be
used to add a layer to the model that transposes back to "channels_first". As before, this will not work for
io_stream.

Returns:
ModelGraph: hls4ml model.
"""
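As a usage illustration of the entry point documented above — a hedged sketch with an invented model and shapes; argument names are taken from the docstring:

```python
import torch.nn as nn
import hls4ml

model = nn.Sequential(nn.Conv1d(3, 8, kernel_size=3), nn.ReLU())

config = hls4ml.utils.config_from_pytorch_model(model)
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model,
    input_shape=[None, 3, 16],  # batch size must be None; channels_first, as in PyTorch
    hls_config=config,
    output_dir='my-hls-test',
)
# Per the Notes above: for io_stream, transpose inputs by hand and set the
# inputs_channel_last argument of config_from_pytorch_model to True.
```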
64 changes: 38 additions & 26 deletions hls4ml/converters/pytorch/convolution.py
@@ -1,83 +1,95 @@
from hls4ml.converters.pytorch_to_hls import pytorch_handler
from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d, parse_data_format
from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler
from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format


@pytorch_handler('Conv1d')
def parse_conv1d_layer(pytorch_layer, layer_name, input_shapes, data_reader, config):
assert 'Conv1d' in pytorch_layer.__class__.__name__
def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'Conv1d' in operation

layer = {}

layer['name'] = layer_name
layer['class_name'] = 'Conv1D'
layer['data_format'] = 'channels_first' # Pytorch default (can't change)

layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight')
layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias')
# Input info
(layer['in_width'], layer['n_chan']) = parse_data_format(
input_shapes[0], 'channels_first'
) # Keras's default is channels_last

# Additional parameters
layer['n_filt'] = pytorch_layer.out_channels
layer['filt_width'] = pytorch_layer.kernel_size[0]
layer['stride_width'] = pytorch_layer.stride[0]
layer['pad_left'] = layer['pad_right'] = pytorch_layer.padding[0]
layer['dilation'] = pytorch_layer.dilation[0]
layer['n_filt'] = class_object.out_channels
layer['filt_width'] = class_object.kernel_size[0]
layer['stride_width'] = class_object.stride[0]
layer['dilation'] = class_object.dilation[0]

if pytorch_layer.padding[0] == 0: # No padding, i.e., 'VALID' padding in Keras/Tensorflow
if type(class_object.padding) is tuple:
padding = class_object.padding[0]
else:
padding = class_object.padding

if padding == 0: # No padding, i.e., 'VALID' padding in Keras/Tensorflow
layer['padding'] = 'valid'
else: # Only 'valid' and 'same' padding are available in Keras
layer['padding'] = 'same'

# Output info
(layer['out_width'], _, _) = compute_padding_1d(
layer['padding'], layer['in_width'], layer['stride_width'], layer['filt_width']
(layer['out_width'], pad_left, pad_right) = compute_padding_1d_pytorch(
padding, layer['in_width'], layer['stride_width'], layer['filt_width'], layer['dilation']
)
layer['pad_left'] = pad_left
layer['pad_right'] = pad_right

output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_width']] # Channel first as default

return layer, output_shape


@pytorch_handler('Conv2d')
def parse_conv2d_layer(pytorch_layer, layer_name, input_shapes, data_reader, config):
assert 'Conv2d' in pytorch_layer.__class__.__name__
def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'Conv2d' in operation

layer = {}

layer['name'] = layer_name
layer['class_name'] = 'Conv2D'
layer['data_format'] = 'channels_first' # Pytorch default (can't change)

layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight')
layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias')
# Input info
(layer['in_height'], layer['in_width'], layer['n_chan']) = parse_data_format(
input_shapes[0], 'channels_first'
) # Keras's default is channels_last

# Additional parameters
layer['n_filt'] = pytorch_layer.out_channels
layer['filt_height'] = pytorch_layer.kernel_size[0]
layer['filt_width'] = pytorch_layer.kernel_size[1]
layer['stride_height'] = pytorch_layer.stride[0]
layer['stride_width'] = pytorch_layer.stride[1]
layer['dilation'] = pytorch_layer.dilation[0]
layer['pad_top'] = layer['pad_bottom'] = pytorch_layer.padding[0]
layer['pad_left'] = layer['pad_right'] = pytorch_layer.padding[1]

if all(x == 0 for x in pytorch_layer.padding): # No padding, i.e., 'VALID' padding in Keras/Tensorflow
layer['n_filt'] = class_object.out_channels
layer['filt_height'] = class_object.kernel_size[0]
layer['filt_width'] = class_object.kernel_size[1]
layer['stride_height'] = class_object.stride[0]
layer['stride_width'] = class_object.stride[1]
layer['dilation'] = class_object.dilation[0]
layer['pad_top'] = layer['pad_bottom'] = class_object.padding[0]
layer['pad_left'] = layer['pad_right'] = class_object.padding[1]

if all(x == 0 for x in class_object.padding): # No padding, i.e., 'VALID' padding in Keras/Tensorflow
layer['padding'] = 'valid'
else: # Only 'valid' and 'same' padding are available in Keras
layer['padding'] = 'same'

# Output info
(layer['out_height'], layer['out_width'], _, _, _, _) = compute_padding_2d(
layer['padding'],
(layer['out_height'], layer['out_width'], _, _, _, _) = compute_padding_2d_pytorch(
class_object.padding,
layer['in_height'],
layer['in_width'],
layer['stride_height'],
layer['stride_width'],
layer['filt_height'],
layer['filt_width'],
class_object.dilation[0],
class_object.dilation[1],
)

output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_height'], layer['out_width']]
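For a `call_module` node, `class_object` is the underlying `nn.Conv1d`/`nn.Conv2d` instance, so the handler reads hyperparameters straight off the module. An illustrative look at the attributes used above (note that `padding` is normally a tuple but can be a string such as 'same' in recent torch, which is why the new tuple check exists):

```python
import torch.nn as nn

conv = nn.Conv1d(3, 8, kernel_size=5, stride=2, padding=2, dilation=1)
# The attributes the Conv1d handler reads off class_object:
print(conv.out_channels)    # 8
print(conv.kernel_size[0])  # 5 -- stored as a tuple, (5,)
print(conv.stride[0])       # 2
print(conv.padding)         # (2,) -- a tuple here, but the string 'same' is also legal
print(conv.dilation[0])     # 1
```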
94 changes: 70 additions & 24 deletions hls4ml/converters/pytorch/core.py
@@ -1,60 +1,91 @@
from hls4ml.converters.pytorch_to_hls import pytorch_handler
from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler


# TODO: propagate use_bias info properly
# https://github.com/fastmachinelearning/hls4ml/issues/409
@pytorch_handler('Linear')
def parse_linear_layer(pytorch_layer, layer_name, input_shapes, data_reader, config):
assert 'Linear' in pytorch_layer.__class__.__name__
def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'Linear' in operation

layer = {}

layer['class_name'] = 'Dense'
layer['name'] = layer_name

layer['n_in'] = pytorch_layer.in_features
layer['n_out'] = pytorch_layer.out_features
layer['weight_data'], layer['bias_data'] = get_weights_data(data_reader, layer['name'], ['weight', 'bias'])
if class_object is not None:
layer['n_in'] = class_object.in_features
layer['n_out'] = class_object.out_features
else:
raise Exception('parsing of torch.nn.functional.linear not supported yet, please use torch.nn.Linear class')

# Handling whether bias is used or not
assert pytorch_layer.bias is not None, "PyTorch Linear with bias=False not yet supported"
if pytorch_layer.bias is None:
if class_object.bias is None:
layer['use_bias'] = False
else:
layer['use_bias'] = True

output_shape = [input_shapes[0][0], layer['n_out']]
output_shape = input_shapes[0][:]
output_shape[-1] = layer['n_out']

return layer, output_shape


# TODO: propagate parametrized activation parameters
# https://github.com/fastmachinelearning/hls4ml/issues/409
# activation_layers = ['LeakyReLU', 'ThresholdedReLU', 'ELU', 'PReLU', 'Softmax', 'ReLU']
activation_layers = ['Softmax', 'ReLU']
activation_layers = ['Softmax', 'ReLU', 'LeakyReLU', 'Threshold', 'ELU', 'PReLU', 'Sigmoid', 'Tanh']


@pytorch_handler(*activation_layers)
def parse_activation_layer(pytorch_layer, layer_name, input_shapes, data_reader, config):
def parse_activation_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
layer = {}

layer['class_name'] = pytorch_layer.__class__.__name__
layer['class_name'] = operation
layer['activation'] = layer['class_name']
layer['name'] = layer_name

if layer['class_name'] == 'ReLU':
layer['class_name'] = 'Activation'
# if layer['class_name'] != 'Activation':
# layer['activation'] = layer['class_name']
if node.op == 'call_module':
if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid':
layer['class_name'] = 'Activation'
if layer['class_name'] == 'LeakyReLU':
layer['activ_param'] = class_object.negative_slope
if layer['class_name'] == 'ELU':
layer['activ_param'] = class_object.alpha
if layer['class_name'] == 'PReLU':
layer['alpha_data'] = get_weights_data(data_reader, layer['name'], 'weight')
if layer['class_name'] == 'Threshold':
layer['activ_param'] = class_object.threshold
layer['class_name'] = 'ThresholdedReLU'
layer['activation'] = 'ThresholdedReLU'
if layer['activ_param'] < 0:
raise Exception('negative threshold values not supported')

if hasattr(node, 'dim'):
layer['axis'] = class_object.dim
else:
if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid':
layer['class_name'] = 'Activation'
if layer['class_name'] == 'LeakyReLU':
layer['activ_param'] = node.kwargs['negative_slope']
if layer['class_name'] == 'ELU':
layer['activ_param'] = node.kwargs['alpha']
if layer['class_name'] == 'Threshold':
layer['activ_param'] = node.args[1]
if layer['activ_param'] < 0:
raise Exception('negative threshold values not supported')
layer['class_name'] = 'ThresholdedReLU'
layer['activation'] = 'ThresholdedReLU'
if 'dim' in node.kwargs:
layer['axis'] = node.kwargs['dim']

output_shape = input_shapes[0]

return layer, output_shape
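The two branches above exist because torch.FX can surface the same activation either as a module (`call_module`, with its parameters on `class_object`) or as a functional call (`call_function`, with its parameters in `node.args`/`node.kwargs`). A small illustration (my example, not from the PR):

```python
import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.act = torch.nn.LeakyReLU(negative_slope=0.1)  # module form

    def forward(self, x):
        x = self.act(x)                              # traced as a call_module node
        return F.leaky_relu(x, negative_slope=0.2)   # traced as a call_function node

for node in symbolic_trace(Net()).graph.nodes:
    print(node.op, node.target, dict(node.kwargs))
```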


batchnorm_layers = ['BatchNorm2d', 'BatchNorm1d']
batchnorm_layers = ['BatchNorm2d', 'BatchNorm1d', 'Batch_norm']


@pytorch_handler(*batchnorm_layers)
def parse_batchnorm_layer(pytorch_layer, layer_name, input_shapes, data_reader, config):
assert 'BatchNorm' in pytorch_layer.__class__.__name__
def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'BatchNorm' in operation

layer = {}

@@ -63,8 +94,23 @@ def parse_batchnorm_layer(pytorch_layer, layer_name, input_shapes, data_reader,
layer['name'] = layer_name

# batchnorm parameters
layer['epsilon'] = pytorch_layer.eps
layer['use_gamma'] = layer['use_beta'] = pytorch_layer.affine
if node.op == 'call_module':
layer['epsilon'] = class_object.eps
layer['use_gamma'] = layer['use_beta'] = class_object.affine

if layer['use_gamma']:
layer['gamma_data'] = get_weights_data(data_reader, layer['name'], 'weight')
else:
layer['gamma_data'] = 1

if layer['use_beta']:
layer['beta_data'] = get_weights_data(data_reader, layer['name'], 'bias')
else:
layer['beta_data'] = 0

layer['mean_data'], layer['variance_data'] = get_weights_data(
data_reader, layer['name'], ['running_mean', 'running_variance']
)

in_size = 1
for dim in input_shapes[0][1:]:
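The scalar fallbacks above (`gamma_data = 1`, `beta_data = 0`) cover `affine=False`, where BatchNorm keeps running statistics but has no learned weight/bias tensors. Illustrative (not from the PR):

```python
import torch.nn as nn

bn = nn.BatchNorm1d(8, affine=False)
print(bn.eps)                 # 1e-05 -> layer['epsilon']
print(bn.weight, bn.bias)     # None None -> fall back to gamma=1, beta=0
print(bn.running_mean.shape)  # torch.Size([8])
print(bn.running_var.shape)   # torch.Size([8])
```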