Commit
fix quant tests
jmduarte committed Jun 24, 2022
1 parent a638428 commit 5fd3aba
Showing 4 changed files with 25 additions and 6 deletions.
7 changes: 6 additions & 1 deletion src/qonnx/converters/keras.py
@@ -2,6 +2,7 @@
import tensorflow as tf
import tf2onnx
from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS
+from collections import OrderedDict

from finn.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup_model
@@ -16,6 +17,9 @@
    "QDepthwiseConv2DBatchnorm",
]

+# Skip remove_identity optimizer
+del tf2onnx.optimizer._optimizers['remove_identity']
+

def add_value_info_for_constants(model: onnx.ModelProto):
    """
@@ -101,14 +105,15 @@ def iterate_model(model):


def _strip_qkeras_model(model):
-    quantizers = {}
+    quantizers = OrderedDict()

    def extract_quantizers(layer):
        keras_cls_name, layer_cfg, layer_quantizers = extract_quantizers_from_layer(layer)
        if layer_quantizers:
            layer_quantizers = {
                k: None if v == "None" else v for k, v in layer_quantizers.items()
            } # Get rid of 'None' strings
+            layer_quantizers["input"] = layer.input.name
            quantizers[layer.name] = layer_quantizers

        layer_class = tf.keras.layers.__dict__.get(keras_cls_name, None)
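The keras.py changes do two related things: they keep ONNX Identity nodes alive by dropping tf2onnx's remove_identity optimizer pass, and they record each QKeras layer's input tensor name so those Identity nodes can later be mapped back to their layer by the handlers below. A minimal sketch of the registry tweak, assuming only what the diff itself relies on (tf2onnx keeping its passes in the optimizer._optimizers dict); the pop() call is an illustrative, tolerant alternative to del and not part of this commit:

import tf2onnx

# Drop the remove_identity pass so Identity nodes emitted for QKeras
# activation layers survive conversion and can receive Quant nodes.
tf2onnx.optimizer._optimizers.pop("remove_identity", None)  # tolerant variant of `del`
assert "remove_identity" not in tf2onnx.optimizer._optimizers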
21 changes: 17 additions & 4 deletions src/qonnx/converters/qkeras/onnx.py
@@ -11,21 +11,29 @@ def get_qkeras_onnx_handlers(all_quantizers):
        "MatMul": (dense_handler, ["MatMul", all_quantizers]),
        "BiasAdd": (bias_handler, ["BiasAdd", all_quantizers]),
        "Relu": (relu_handler, ["Relu", all_quantizers]),
+        "Identity": (identity_handler, ["Identity", all_quantizers]),
    }


-def _extract_node_name(onnx_name, keras_names):
+def _extract_node_name(onnx_node, keras_quantizers):
+    onnx_name = onnx_node.name
+    keras_names = keras_quantizers.keys()
    for keras_name in keras_names:
        match = "/" + keras_name + "/"
        if match in onnx_name:
            return keras_name
+        elif "Identity" in onnx_name:
+            onnx_input = onnx_node.input[0]
+            keras_input = keras_quantizers[keras_name]["input"]
+            if keras_input in onnx_input:
+                return keras_name

    return None


def qlayer_handler(ctx, node, name, args):
    all_quantizers = args[0]
-    keras_name = _extract_node_name(name, all_quantizers.keys())
+    keras_name = _extract_node_name(node, all_quantizers)
    if not keras_name:
        return # Not found in quantizers, nothing to do
    quantizers = all_quantizers[keras_name]
@@ -79,7 +87,7 @@ def qlayer_handler(ctx, node, name, args):

def qact_handler(ctx, node, name, args):
    all_quantizers = args[0]
-    keras_name = _extract_node_name(name, all_quantizers.keys())
+    keras_name = _extract_node_name(node, all_quantizers)
    if not keras_name:
        return # Not found in quantizers, nothing to do
    quantizers = all_quantizers[keras_name]
@@ -119,7 +127,7 @@ def bias_handler(ctx, node, name, args):
    BiasAdd.version_1(ctx, node)

    all_quantizers = args[0]
-    keras_name = _extract_node_name(name, all_quantizers.keys())
+    keras_name = _extract_node_name(node, all_quantizers)
    if not keras_name:
        return # Not found in quantizers, nothing to do
    quantizers = all_quantizers[keras_name]
@@ -140,3 +148,8 @@ def bias_handler(ctx, node, name, args):
def relu_handler(ctx, node, name, args):
    DirectOp.version_1(ctx, node)
    qact_handler(ctx, node, name, args)
+
+
+def identity_handler(ctx, node, name, args):
+    DirectOp.version_1(ctx, node)
+    qact_handler(ctx, node, name, args)
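The reworked _extract_node_name receives the whole ONNX node rather than just its name: Identity nodes carry no layer name, so they are matched through the input tensor name recorded in keras.py. A self-contained sketch of that matching logic, using a hypothetical FakeNode stand-in (not a tf2onnx type) purely for illustration:

from collections import OrderedDict

class FakeNode:
    """Stand-in for a tf2onnx node: just a name and a list of input tensor names."""
    def __init__(self, name, inputs):
        self.name = name
        self.input = inputs

def extract_node_name(onnx_node, keras_quantizers):
    onnx_name = onnx_node.name
    for keras_name in keras_quantizers:
        if "/" + keras_name + "/" in onnx_name:
            return keras_name
        elif "Identity" in onnx_name:
            # Identity nodes lack the layer name, so fall back to matching the
            # recorded Keras input tensor against the node's first input.
            if keras_quantizers[keras_name]["input"] in onnx_node.input[0]:
                return keras_name
    return None

quantizers = OrderedDict([("act0", {"input": "model/act0/Relu:0"})])
print(extract_node_name(FakeNode("model/act0/Relu", ["x:0"]), quantizers))         # -> act0
print(extract_node_name(FakeNode("Identity", ["model/act0/Relu:0"]), quantizers))  # -> act0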
1 change: 1 addition & 0 deletions src/qonnx/converters/qkeras/quantizers.py
@@ -94,6 +94,7 @@ def convert_ternary(tensor, quantizer):
    ternary = qkeras.ternary()
    t = ternary.default_threshold
    assert t == 0.5, "ternary - only threshold 0.5 is supported"
+    # note that if assertions fail, Quant node is not inserted, but model is still converted; this seems to be unexpected behavior
    scale = 1.0
    zero_point = 0
    bit_width = 2
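For context on the numbers above: scale 1.0, zero point 0, and bit width 2 describe the integer grid a ternary weight lives on. A rough numeric sketch of that grid, assuming a narrow signed 2-bit range; this is only an illustration, not the QONNX Quant node implementation:

import numpy as np

def fake_ternary_quant(x, scale=1.0, zero_point=0, bit_width=2):
    # Narrow signed 2-bit range [-1, 1], matching ternary values {-1, 0, +1}.
    qmin, qmax = -(2 ** (bit_width - 1)) + 1, 2 ** (bit_width - 1) - 1
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

print(fake_ternary_quant(np.array([-0.9, 0.2, 0.6])))  # -> [-1.  0.  1.]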
2 changes: 1 addition & 1 deletion tests/test_keras_convert.py
@@ -32,7 +32,7 @@
    quantized_bits(4, 2, 0, alpha=1),
    quantized_bits(2, 2, 1, alpha=1),
    quantized_bits(2, 1, 1, alpha=1),
-    ternary(alpha=1),
+    ternary(alpha=1, threshold=0.5),
    binary(alpha=1),
]
act_quantizers_ids = list(range(len(act_quantizers)))
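The test change pins the ternary threshold to the only value the converter accepts (see the assert in quantizers.py above). A hedged sketch of the kind of QKeras activation layer the parametrized test would feed through the converter; the exact test model is not shown in this diff:

import tensorflow as tf
from qkeras import QActivation, ternary

# Activation layer of the sort the parametrized test converts; the explicit
# threshold=0.5 matches the converter's assertion.
inp = tf.keras.layers.Input(shape=(16,))
out = QActivation(ternary(alpha=1, threshold=0.5))(inp)
model = tf.keras.Model(inp, out)
model.summary()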
