Skip to content

Commit

Permalink
weight inference
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Oct 12, 2018
1 parent bb16898 commit 5eee8f7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
40 changes: 20 additions & 20 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,26 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
"in the format of ((before_1, after_1), ..., (before_N, after_N))");
}
};
/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
TVM_ATTR_FIELD(units)
.describe("Number of hidden units of the dense transformation.");
}
};

/*! \brief Attributes for leaky relu operator */
struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
double alpha;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.LeakyReluAttrs") {
TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25)
.describe("Slope coefficient for the negative half axis.");
}
};


/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
Expand Down Expand Up @@ -270,26 +290,6 @@ struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
.describe("A lower bound value for the norm, to avoid division by 0.");
TVM_ATTR_FIELD(axis)
.describe("Axis over the normalization applied.");


/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
TVM_ATTR_FIELD(units)
.describe("Number of hidden units of the dense transformation.");
}
};


/*! \brief Attributes for leaky rely operator */
struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
double alpha;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.LeakyReluAttrs") {
TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25)
.describe("Slope coefficient for the negative half axis.");
}
};

Expand Down
16 changes: 15 additions & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ bool DenseRel(const Array<Type>& types,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;

const DenseAttrs* param = attrs.as<DenseAttrs>();
Expand All @@ -33,9 +34,22 @@ bool DenseRel(const Array<Type>& types,

Array<tvm::Expr> oshape = data->shape;
if (param->units.defined()) {
Array<tvm::Expr> dshape = data->shape;

// validate the weight shape is proper if defined
if (weight != nullptr) {
CHECK(reporter->AssertEQ(weight->shape[0], dshape[dshape.size() - 1]))
<< "Dense: shape of weight is inconsistent with input data.";
CHECK(reporter->AssertEQ(weight->shape[1], param->units))
<< "Dense: shape of weight is inconsistent with units.";
} else {
// Assign weight type
std::vector<IndexExpr> wshape({dshape[dshape.size() - 1], param->units});
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
}

oshape.Set((oshape.size() - 1), param->units);
} else {
const auto* weight = types[1].as<TensorTypeNode>();
if (weight == nullptr) return false;
Array<tvm::Expr> wshape = weight->shape;
oshape.Set((oshape.size() - 1), wshape[wshape.size() - 1]);
Expand Down

0 comments on commit 5eee8f7

Please sign in to comment.