Skip to content

Commit

Permalink
[RELAY]Reduce ops
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Oct 18, 2018
1 parent 4001569 commit 21acafe
Show file tree
Hide file tree
Showing 4 changed files with 347 additions and 89 deletions.
10 changes: 10 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ This level enables additional math and transform operators.
tvm.relay.where
tvm.relay.argmax
tvm.relay.argmin
tvm.relay.sum
tvm.relay.max
tvm.relay.min
tvm.relay.mean
tvm.relay.prod


**Level 5: Vision/Image Operators**
Expand Down Expand Up @@ -187,6 +192,11 @@ Level 4 Definitions
.. autofunction:: tvm.relay.where
.. autofunction:: tvm.relay.argmax
.. autofunction:: tvm.relay.argmin
.. autofunction:: tvm.relay.sum
.. autofunction:: tvm.relay.max
.. autofunction:: tvm.relay.min
.. autofunction:: tvm.relay.mean
.. autofunction:: tvm.relay.prod


Level 5 Definitions
Expand Down
193 changes: 193 additions & 0 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,196 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
"""

return _make.argmin(data, axis, keepdims, exclude)


def sum(data, axis=None, keepdims=False, exclude=False):
"""Computes the sum of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
sum(data, axis=1)
[[ 4. 8.]
[ 10. 9.]
[ 21. 6.]]
sum(data, axis=[1,2])
[ 12. 19. 27.]
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""

return _make.sum(data, axis, keepdims, exclude)


def max(data, axis=None, keepdims=False, exclude=False):
""" Computes the max of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""

return _make.max(data, axis, keepdims, exclude)


def min(data, axis=None, keepdims=False, exclude=False):
"""Computes the min of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""

return _make.min(data, axis, keepdims, exclude)


def mean(data, axis=None, keepdims=False, exclude=False):
"""Computes the mean of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data)
[3.22]
mean(data, axis=[1,2])
[ 2. 3.16666667 4.5]
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""

return _make.mean(data, axis, keepdims, exclude)


def prod(data, axis=None, keepdims=False, exclude=False):
"""Computes the products of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data, axis=1)
[35562240]
mean(data, axis=[1,2])
[ 36 480 2058]
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""

return _make.prod(data, axis, keepdims, exclude)
111 changes: 111 additions & 0 deletions src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,34 @@ bool ArgReduceRel(const Array<Type>& types,
return true;
}

/*!
* \brief ReduceRel Output type and shape relation evaluation function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return false if This relation cannot be resolved. true if this relation has been resolved.
*/
bool ReduceRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr> in_shape;
for (auto i : data->shape) {
in_shape.push_back(i);
}

const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);

// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}

#define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
Expand Down Expand Up @@ -213,5 +241,88 @@ values over a given axis.
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);


RELAY_REGISTER_REDUCE_OP("sum")
.describe(R"code(Computes the sum of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
sum(data, axis=1)
[[ 4. 8.]
[ 10. 9.]
[ 21. 6.]]
sum(data, axis=[1,2])
[ 12. 19. 27.]
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("max")
.describe(R"code(Computes the max of array elements over given axes.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("min")
.describe(R"code(Computes the min of array elements over given axes.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("mean")
.describe(R"code(Computes the mean of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data)
[3.22]
mean(data, axis=[1,2])
[ 2. 3.16666667 4.5]
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("prod")
.describe(R"code(Computes the products of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data, axis=1)
[35562240]
mean(data, axis=[1,2])
[ 36 480 2058]
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);

} // namespace relay
} // namespace tvm
Loading

0 comments on commit 21acafe

Please sign in to comment.