Skip to content

Commit

Permalink
[Caffe Frontend] extending Eltwise to handle multiple inputs (#8136)
Browse files Browse the repository at this point in the history
* [Caffe Frontend] adding Reduction op

* reformatting Reduction op test script

* reformatting Reduction test script

* [Caffe frontend] Reduction op
- adding more test cases; handling '0 < axis < num_axes - 1' case to give the result equivalent to Caffe framework
- skipping Relay multiplication if coeff is 1

Signed-off-by: zotanika <zotanika@gmail.com>

* linting test script

* linting

* [Caffe Frontend] Supporting multiple grouped(channel-wise) Deconv op

* Handling group > 1 cases, assuming group == output channels
* Decomposed into Relay split, transposed conv, and multi-leveled concatenation.
* Added some test cases.

Signed-off-by: zotanika <zotanika@gmail.com>

* [Caffe Frontend] supporting variable number of inputs for Eltwise

* extra handling of rest inputs for PROD, SUM, MAX operations
* extra testcases

Signed-off-by: zotanika <zotanika@gmail.com>

* formatting fix

* [Caffe Frontend] reverting codes related Reduction for splitting PR

* Revert "[Caffe Frontend] Supporting multiple grouped(channel-wise) Deconv op"

This reverts commit 43e25e5.

* instant fix against docker format error

* instant fix against docker format error

* instant fix against docker format error
  • Loading branch information
zotanika authored Jan 17, 2022
1 parent be0677d commit 3c8de42
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
25 changes: 22 additions & 3 deletions python/tvm/relay/frontend/caffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,13 @@ def convert_flatten(self, op):
def convert_eltwise(self, op):
"""Convert Eltwise layer"""
inputs = op.bottom
assert len(inputs) == 2, "input tensors length should be 2"
assert len(inputs) >= 2, "input tensors length should be larger than 2"

# gethering initial 2 input expressions
lhs_expr = self.exp_tab.get_expr(inputs[0])
rhs_expr = self.exp_tab.get_expr(inputs[1])

lhs_shape = _infer_shape(lhs_expr)
rhs_shape = _infer_shape(rhs_expr)

assert lhs_shape == rhs_shape, "input tensors shape should be equal"

eltwise_params = op.eltwise_param
Expand All @@ -100,6 +99,11 @@ def convert_eltwise(self, op):

if eltwise_type_dict[eltwise_type] == "PROD":
out = _op.multiply(lhs_expr, rhs_expr)
# for rest inputs
for i in range(len(inputs) - 2):
extra_expr = self.exp_tab.get_expr(inputs[i + 2])
assert _infer_shape(out) == _infer_shape(extra_expr)
out = _op.multiply(out, extra_expr)
elif eltwise_type_dict[eltwise_type] == "SUM":
if coeff:
left_coeff_expr = self.exp_tab.new_const(np.asarray(coeff[0], np.float32))
Expand All @@ -109,8 +113,23 @@ def convert_eltwise(self, op):
out = _op.add(lhs_expr_scale, rhs_expr_scale)
else:
out = _op.add(lhs_expr, rhs_expr)
# for rest inputs
for i in range(len(inputs) - 2):
extra_expr = self.exp_tab.get_expr(inputs[i + 2])
assert _infer_shape(out) == _infer_shape(extra_expr)
if coeff:
coeff_expr = self.exp_tab.new_const(np.asarray(coeff[i + 2], np.float32))
extra_expr_scale = _op.multiply(extra_expr, coeff_expr)
out = _op.add(out, extra_expr_scale)
else:
out = _op.add(out, extra_expr)
elif eltwise_type_dict[eltwise_type] == "MAX":
out = _op.maximum(lhs_expr, rhs_expr)
# for rest inputs
for i in range(len(inputs) - 2):
extra_expr = self.exp_tab.get_expr(inputs[i + 2])
assert _infer_shape(out) == _infer_shape(extra_expr)
out = _op.maximum(out, extra_expr)
else:
raise tvm.error.OpNotImplemented(
"eltwise_type {} is not supported for frontend Caffe.".format(eltwise_type)
Expand Down
39 changes: 39 additions & 0 deletions tests/python/frontend/caffe/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,45 @@ def test_forward_Eltwise():
operation=1,
coeff=[0.5, 1],
)
_test_eltwise(
[
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
],
operation=0,
)
_test_eltwise(
[
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
],
operation=1,
)
_test_eltwise(
[
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
],
operation=2,
)
_test_eltwise(
[
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
np.random.rand(1, 3, 10, 11).astype(np.float32),
],
operation=1,
coeff=[0.5, 1, 0.2, 1.8, 3.1, 0.1],
)


#######################################################################
Expand Down

0 comments on commit 3c8de42

Please sign in to comment.