[Relay][OP] Fix bias_add default axis (apache#2829)
* Fix bias add default axis

* update

* Fix canonicalize ops for bias_add
icemelon authored and wweic committed Apr 10, 2019
1 parent 1ea177f commit 5ce9d59
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 8 deletions.
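
For context (not part of the commit): relay.nn.bias_add broadcasts a 1-D bias along one axis of its input, and its default axis targets the channel dimension of NCHW conv outputs. After nn.dense, whose result has shape (batch, units), the bias belongs on the last axis, so the frontends and testing networks below now pass axis=-1 explicitly instead of leaning on the default. A minimal sketch of the resulting pattern, assuming a Relay build of this era (shapes are illustrative):

from tvm import relay

# Illustrative shapes: dense maps (4, 16) x (8, 16) -> (4, 8).
data = relay.var("data", shape=(4, 16))
weight = relay.var("weight", shape=(8, 16))
bias = relay.var("bias", shape=(8,))

fc = relay.nn.dense(data, weight, units=8)  # output shape (4, 8)
fc = relay.nn.bias_add(fc, bias, axis=-1)   # add bias along the units axis
func = relay.Function([data, weight, bias], fc)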
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/mxnet.py
@@ -34,7 +34,7 @@ def _mx_fully_connected(inputs, attrs):
     res = _op.nn.dense(inputs[0], inputs[1], units=units)
     if use_bias:
         assert len(inputs) == 3
-        res = _op.nn.bias_add(res, inputs[2])
+        res = _op.nn.bias_add(res, inputs[2], axis=-1)
     return res


@@ -413,7 +413,7 @@ def _mx_batch_dot(inputs, attrs):
         raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
     if transpose_b is False:
         b = _op.transpose(b, axes=[0, 2, 1])
-    return _op.batch_matmul(a, b)
+    return _op.nn.batch_matmul(a, b)


 def _mx_arange(inputs, attrs):
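
Besides the axis change, this hunk also calls batch_matmul under its registered nn namespace. As a hedged reminder of why the converter transposes b first: relay.nn.batch_matmul takes its second operand in (batch, N, K) layout and contracts over the last axis of both inputs. A NumPy shape check of that convention (names here are illustrative):

import numpy as np

B, M, K, N = 2, 3, 4, 5
a = np.random.rand(B, M, K).astype("float32")
b = np.random.rand(B, K, N).astype("float32")  # MXNet batch_dot layout

# What _mx_batch_dot does when transpose_b is False:
b_t = b.transpose(0, 2, 1)                     # -> (B, N, K)
# batch_matmul semantics: contract the last axes of both operands.
out = np.einsum("bmk,bnk->bmn", a, b_t)
assert out.shape == (B, M, N)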
2 changes: 1 addition & 1 deletion python/tvm/relay/testing/inception_v3.py
@@ -248,7 +248,7 @@ def get_net(batch_size,

     flatten = relay.nn.batch_flatten(pool)
     fc1 = relay.nn.dense(flatten, relay.var("fc1_weight"), units=num_classes)
-    fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"))
+    fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"), axis=-1)
     inception_v3 = relay.nn.softmax(data=fc1)
     args = relay.ir_pass.free_vars(inception_v3)
     return relay.Function(args, inception_v3)
2 changes: 1 addition & 1 deletion python/tvm/relay/testing/layers.py
@@ -134,5 +134,5 @@ def dense_add_bias(data, weight=None, bias=None, units=None, **kwargs):
     if not bias:
         bias = relay.var(name + "_bias")
     data = relay.nn.dense(data, weight, units, **kwargs)
-    data = relay.nn.bias_add(data, bias)
+    data = relay.nn.bias_add(data, bias, axis=-1)
     return data
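
A hedged usage sketch of this testing helper (the name kwarg is consumed earlier in the full function, outside this hunk, to name the auto-created weight and bias vars):

from tvm import relay
from tvm.relay.testing import layers

x = relay.var("x", shape=(1, 64))
# Creates fc_weight and fc_bias vars and emits dense + bias_add(axis=-1).
y = layers.dense_add_bias(x, units=10, name="fc")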
6 changes: 3 additions & 3 deletions python/tvm/relay/testing/mlp.py
@@ -50,13 +50,13 @@ def get_net(batch_size,
                      dtype=dtype)
     data = relay.nn.batch_flatten(data)
     fc1 = relay.nn.dense(data, relay.var("fc1_weight"), units=128)
-    fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"))
+    fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"), axis=-1)
     act1 = relay.nn.relu(fc1)
     fc2 = relay.nn.dense(act1, relay.var("fc2_weight"), units=64)
-    fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias"))
+    fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias"), axis=-1)
     act2 = relay.nn.relu(fc2)
     fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=num_classes)
-    fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"))
+    fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"), axis=-1)
     mlp = relay.nn.softmax(data=fc3)
     args = relay.ir_pass.free_vars(mlp)
     return relay.Function(args, mlp)
6 changes: 5 additions & 1 deletion src/relay/pass/canonicalize_ops.cc
@@ -24,7 +24,11 @@ class BiasAddSimplifier : public ExprMutator {

       auto ttype = n->args[0]->type_as<TensorTypeNode>();
       size_t n_dim = ttype->shape.size();
-      Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis});
+      int axis = param->axis;
+      if (axis < 0) {
+        axis += n_dim;
+      }
+      Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {axis});
       Expr ret = Add(call->args[0], expanded_bias);
       ret->checked_type_ = n->checked_type_;
       return ret;
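
To see why the normalization matters: ExpandBiasToMatchAxis reshapes the 1-D bias so that broadcasting adds it along the requested axis, and its index arithmetic assumes a non-negative axis, so with callers now passing axis=-1 the pass must fold negative values into range first. A hedged NumPy model of the intended behavior (the helper name here is illustrative, not the pass's API):

import numpy as np

def expand_bias_to_match_axis(bias, n_dim, axis):
    if axis < 0:
        axis += n_dim            # the fix: axis=-1, n_dim=2 -> axis=1
    # Append one unit dim for every axis to the right of `axis`.
    for _ in range(n_dim - axis - 1):
        bias = np.expand_dims(bias, -1)
    return bias

x = np.zeros((4, 8), "float32")          # dense output (batch, units)
bias = np.arange(8, dtype="float32")
y = x + expand_bias_to_match_axis(bias, x.ndim, -1)
assert y.shape == (4, 8) and np.allclose(y[0], bias)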