Review comments fixed and rebased
siju-samuel committed Oct 9, 2018
1 parent 852ef30, commit f437e2b
Showing 3 changed files with 23 additions and 8 deletions.
python/tvm/relay/op/nn/nn.py (2 changes: 1 addition & 1 deletion)
@@ -433,4 +433,4 @@ def leaky_relu(data, alpha):
     result : relay.Expr
         The computed result.
     """
-    return _make.leaky_relu(data, alpha)
\ No newline at end of file
+    return _make.leaky_relu(data, alpha)
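
A minimal usage sketch for the leaky_relu wrapper shown above, written in the same IRBuilder style as this PR's tests. The imports and the relay.ir_builder/relay.ir_pass helpers are assumed to match tests/python/relay/test_op_level2.py at this revision; the alpha value is arbitrary and the assert only reflects the expected shape-preserving behavior, which this commit does not itself test.

import tvm
from tvm import relay

ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))

with ib.function(x) as func:
    # leaky_relu is elementwise, so the inferred return type is expected
    # to equal the input type (assumption, not asserted by this commit)
    ib.ret(relay.nn.leaky_relu(x.var, alpha=0.1))
ib.ret(func)

func = relay.ir_pass.infer_type(ib.env, func.to_func())
assert func.checked_type.ret_type == relay.ty.TensorType((n, c, h, w), "float32")
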
src/relay/op/nn/nn.cc (14 changes: 7 additions & 7 deletions)
@@ -25,22 +25,22 @@ bool DenseRel(const Array<Type>& types,
   CHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) return false;
-  const auto* weight = types[1].as<TensorTypeNode>();
-  if (weight == nullptr) return false;
 
   const DenseAttrs* param = attrs.as<DenseAttrs>();
   CHECK(param != nullptr);
 
   CHECK(static_cast<int>(data->shape.size()) != 0);
-  Array<tvm::Expr> wshape = weight->shape;
 
+  Array<tvm::Expr> oshape = data->shape;
   if (param->units.defined()) {
-    CHECK(reporter->AssertEQ(param->units, wshape[wshape.size()-1]));
+    oshape.Set((oshape.size() - 1), param->units);
+  } else {
+    const auto* weight = types[1].as<TensorTypeNode>();
+    if (weight == nullptr) return false;
+    Array<tvm::Expr> wshape = weight->shape;
+    oshape.Set((oshape.size() - 1), wshape[wshape.size() - 1]);
   }
-
-  Array<tvm::Expr> oshape = data->shape;
-  oshape.Set((oshape.size() - 1), wshape[wshape.size() - 1]);
 
   // assign output type
   reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
   return true;
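
The hunk above makes the weight type check conditional: when units is specified, DenseRel takes the last output dimension from param->units; only when units is left undefined does it read, and therefore require, the weight tensor's shape. A minimal sketch of the units-defined path, mirroring the IRBuilder style of the tests in test_op_level2.py below; the units keyword on relay.nn.dense and the weight layout with the output size as its last axis are assumptions inferred from this diff, not confirmed by it.

import tvm
from tvm import relay

ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
# assumed layout: the weight's last axis carries the output units; with
# units set, the new else-branch above never inspects this shape
wt = ib.param("wt", relay.ty.TensorType((w, 2), "float32"))

with ib.function(x, wt) as func:
    ib.ret(relay.nn.dense(x.var, wt.var, units=2))
ib.ret(func)

func = relay.ir_pass.infer_type(ib.env, func.to_func())
# units=2 pins the last dimension of the inferred output type
assert func.checked_type.ret_type == relay.ty.TensorType((n, c, h, 2), "float32")
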
tests/python/relay/test_op_level2.py (15 changes: 15 additions & 0 deletions)
@@ -173,8 +173,23 @@ def test_dense_infer_type():
     ib.ret(func)
     func = relay.ir_pass.infer_type(ib.env, func.to_func())
     ftype = func.checked_type
     assert ftype.ret_type == relay.ty.TensorType((n, c, h, 2), "float32")
 
+    ib = relay.ir_builder.IRBuilder()
+    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
+    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
+
+    wh, ww = tvm.var("wh"), tvm.var("ww")
+    w = ib.param("w", relay.ty.TensorType((wh, ww), "float32"))
+
+    with ib.function(x, w) as func:
+        ib.ret(relay.nn.dense(x.var, w.var))
+    ib.ret(func)
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    assert ftype.ret_type == relay.ty.TensorType((n, c, h, ww), "float32")
+
+
 if __name__ == "__main__":
     test_conv2d_infer_type()
     test_pool2d_infer_type()
