
add batch_dim support
zxy844288792 committed May 3, 2021
1 parent f862333 commit 99fd8ac
Showing 10 changed files with 174 additions and 74 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/transform.h

@@ -144,10 +144,14 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
 };
 
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
+  Integer batch_dims;
   Integer axis;
   std::string mode;
 
   TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
+    TVM_ATTR_FIELD(batch_dims)
+        .set_default(0)
+        .describe("The number of batch dimensions.");
     TVM_ATTR_FIELD(axis)
         .set_default(NullValue<Integer>())
         .describe("The axis over which to select values.");
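
At the Python level this attribute surfaces as a new relay.take argument (see python/tvm/relay/op/transform.py below). A minimal usage sketch, assuming a TVM build that includes this change; the shapes are illustrative:

from tvm import relay

data = relay.var("data", shape=(2, 3, 4), dtype="float32")
indices = relay.var("indices", shape=(2, 5), dtype="int32")
# batch_dims=1: row b of `indices` gathers along axis 1 of data[b], so the
# result shape is (2, 5, 4); with batch_dims=0 it would be (2, 2, 5, 4).
out = relay.take(data, indices, axis=1, batch_dims=1)
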
119 changes: 87 additions & 32 deletions include/tvm/topi/transform.h
@@ -770,8 +770,9 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
  *
  * \return A Tensor whose op member is the take operation
  */
-inline Tensor take(const Tensor& a, const Tensor& indices, std::string mode = "clip",
-                   std::string name = "T_take", std::string tag = kInjective) {
+inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
+                   std::string mode = "clip", std::string name = "T_take",
+                   std::string tag = kInjective) {
   Array<PrimExpr> a_shape = a->shape;
   Array<PrimExpr> out_shape = indices->shape;
   PrimExpr a_size = 1;
@@ -846,6 +847,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
  *
  * \param a The source array.
  * \param indices The indices of the values to extract.
+ * \param batch_dims The number of batch dimensions. Defaults to 0.
  * \param axis The axis over which to select values. By default,
  * the flattened input array is used.
  * \param mode The mode for handling out of bound indices.
@@ -854,46 +856,99 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
  *
  * \return A Tensor whose op member is the take operation
  */
-inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string mode = "clip",
-                   std::string name = "T_take", std::string tag = kInjective) {
+inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
+                   std::string mode = "clip", std::string name = "T_take",
+                   std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(a->shape.size());
   }
   ICHECK_GE(axis, 0) << "axis out of bounds";
   ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
   auto axis_dim = a->shape[axis];
 
   int indices_len = static_cast<int>(indices->shape.size());
-  Array<PrimExpr> out_shape;
-  for (size_t i = 0; i < a->shape.size(); ++i) {
-    if (axis == static_cast<int>(i)) {
-      for (size_t j = 0; j < indices->shape.size(); ++j) {
-        out_shape.push_back(indices->shape[j]);
-      }
-    } else {
-      out_shape.push_back(a->shape[i]);
-    }
-  }
+
+  int batch_dims_ = batch_dims;
+  if (batch_dims_ != 0) {
+    ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
+    ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";
+
+    if (batch_dims_ < 0) {
+      batch_dims_ = indices->shape.size() + batch_dims_;
+    }
+
+    ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
+    ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
+
+    for (int i = 0; i < batch_dims_; ++i) {
+      auto v1 = a->shape[i].as<IntImmNode>()->value;
+      auto v2 = indices->shape[i].as<IntImmNode>()->value;
+      ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
+    }
+  }
+
+  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
+  // a.shape[axis + 1:].
+  Array<PrimExpr> out_shape;
+  for (int i = 0; i < batch_dims_; ++i) {
+    out_shape.push_back(a->shape[i]);
+  }
+  for (int i = batch_dims_; i < axis; ++i) {
+    out_shape.push_back(a->shape[i]);
+  }
+  for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
+    out_shape.push_back(indices->shape[i]);
+  }
+  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
+    out_shape.push_back(a->shape[i]);
+  }
+
   if (mode == "clip") {
-    return compute(
-        out_shape,
-        [&](const Array<Var>& out_index) {
-          Array<PrimExpr> indices_position;
-          for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
-            indices_position.push_back(out_index[j]);
-          }
-          Array<PrimExpr> real_indices;
-          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
-            real_indices.push_back(out_index[j]);
-          }
-          auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
-          real_indices.push_back(idx);
-          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
-            real_indices.push_back(out_index[j]);
-          }
-          return a(real_indices);
-        },
-        name, tag);
+    if (batch_dims_ == 0) {
+      return compute(
+          out_shape,
+          [&](const Array<Var>& out_index) {
+            Array<PrimExpr> indices_position;
+            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
+              indices_position.push_back(out_index[j]);
+            }
+            Array<PrimExpr> real_indices;
+            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
+            real_indices.push_back(idx);
+            for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            return a(real_indices);
+          },
+          name, tag);
+    } else {
+      return compute(
+          out_shape,
+          [&](const Array<Var>& out_index) {
+            Array<PrimExpr> indices_position;
+            for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
+              indices_position.push_back(out_index[j]);
+            }
+            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
+              indices_position.push_back(out_index[j]);
+            }
+            Array<PrimExpr> real_indices;
+            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
+            real_indices.push_back(idx);
+            for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            return a(real_indices);
+          },
+          name, tag);
+    }
   } else if (mode == "fast") {
     LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                     "Make sure input indices are in bounds";
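
The compute rules above implement the usual batched-gather semantics. As a cross-check, a NumPy reference sketch (an illustrative helper, not part of the commit): peel off one batch dimension at a time, then fall back to np.take once no batch dimensions remain.

import numpy as np

def take_with_batch_dims(a, indices, axis, batch_dims):
    # Reference semantics only: a.shape[:batch_dims] must match indices.shape[:batch_dims].
    if batch_dims == 0:
        return np.take(a, indices, axis=axis)
    assert a.shape[0] == indices.shape[0]
    return np.stack(
        [take_with_batch_dims(a[b], indices[b], axis - 1, batch_dims - 1) for b in range(a.shape[0])]
    )

a = np.arange(24).reshape(2, 3, 4)
idx = np.array([[0, 2], [1, 1]])
# a.shape[:axis] + indices.shape[batch_dims:] + a.shape[axis + 1:] == (2, 2, 4)
assert take_with_batch_dims(a, idx, axis=1, batch_dims=1).shape == (2, 2, 4)
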
15 changes: 10 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1806,14 +1806,19 @@ def _impl(inputs, attr, params, mod):
         if len(inputs) > 2:
             axis = _get_num_param(params, inputs.pop(2))
         else:
             axis = 0
+        batch_dims = 0
         if int(attr.get("batch_dims", 0)) != 0:
-            raise tvm.error.OpAttributeUnImplemented("Attribute batch_dims is not supported")
+            batch_dims = int(attr.get("batch_dims", 0))
         new_input = inputs[0:2]
-        return AttrCvt(
+        op_ = AttrCvt(
             op_name="take",
-            extras={"axis": tvm.tir.const(axis, "int32")},
-            ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class", "batch_dims"],
+            extras={
+                "axis": tvm.tir.const(axis, "int32"),
+                "batch_dims": tvm.tir.const(batch_dims, "int32"),
+            },
+            ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class"],
         )(new_input, attr)
+        return op_
 
     return _impl

@@ -3916,4 +3921,4 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
         from tvm.relay.frontend.tensorflow2 import from_tensorflow as _from_tensorflow2
         mod, params = _from_tensorflow2(graph, layout, shape, outputs)
 
-    return mod, params
\ No newline at end of file
+    return mod, params
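
With the unimplemented-attribute error removed, GatherV2 nodes that carry a nonzero batch_dims now convert to relay.take. For reference, the TensorFlow-level behavior being mapped (a sketch; the values are arbitrary):

import tensorflow as tf

params = tf.reshape(tf.range(24.0), (2, 3, 4))
indices = tf.constant([[0, 2], [1, 1]])
# Row b of `indices` gathers along axis 1 of params[b].
out = tf.gather(params, indices, axis=1, batch_dims=1)  # shape (2, 2, 4)
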
15 changes: 10 additions & 5 deletions python/tvm/relay/op/_transform.py
@@ -363,7 +363,7 @@ def _take_no_axis_shape_func(indices_shape, out_ndim):
 
 
 @script
-def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
+def _take_with_axis_shape_func(data_shape, indices_shape, axis, batch_dims, out_ndim):
     out = output_tensor((out_ndim,), "int64")
     for i in const_range(axis):
         out[i] = data_shape[i]
@@ -372,10 +372,10 @@ def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
         for i in const_range(axis + 1, len(data_shape)):
             out[i - 1] = data_shape[i]
     else:
-        for i in const_range(len(indices_shape)):
-            out[axis + i] = indices_shape[i]
+        for i in const_range(len(indices_shape) - batch_dims):
+            out[axis + i] = indices_shape[i + batch_dims]
         for i in const_range(axis + 1, len(data_shape)):
-            out[len(indices_shape) + i - 1] = data_shape[i]
+            out[len(indices_shape) + i - 1 - batch_dims] = data_shape[i]
     return out


@@ -387,11 +387,16 @@ def take_shape_func(attrs, inputs, out_ndims):
     if attrs.axis is None:
         return [_take_no_axis_shape_func(inputs[1], out_ndims[0])]
     axis = get_const_int(attrs.axis)
+    batch_dims = get_const_int(attrs.batch_dims)
     data_ndim = int(inputs[0].shape[0])
+    if inputs[1].shape:
+        indices_ndim = int(inputs[1].shape[0])
     if axis < 0:
         axis += data_ndim
     assert 0 <= axis < data_ndim
-    return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
+    if batch_dims < 0:
+        batch_dims += indices_ndim
+    return [_take_with_axis_shape_func(*inputs, convert(axis), convert(batch_dims), out_ndims[0])]


@_reg.register_legalize("take")
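
The index arithmetic in _take_with_axis_shape_func reduces to a slicing identity. A plain-Python sketch of the same computation (a hypothetical helper, for illustration only):

def take_out_shape(data_shape, indices_shape, axis, batch_dims):
    # Batch dimensions come from data_shape; the leading batch_dims entries of
    # indices_shape are dropped, mirroring the hybrid-script loops above.
    return list(data_shape[:axis]) + list(indices_shape[batch_dims:]) + list(data_shape[axis + 1:])

assert take_out_shape((2, 3, 4), (2, 5), axis=1, batch_dims=1) == [2, 5, 4]
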
7 changes: 5 additions & 2 deletions python/tvm/relay/op/transform.py
@@ -385,7 +385,7 @@ def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_e
     return _make.reshape_like(data, shape_like, lhs_begin, lhs_end, rhs_begin, rhs_end)
 
 
-def take(data, indices, axis=None, mode="clip"):
+def take(data, indices, axis=None, batch_dims=0, mode="clip"):
     """Take elements from an array along an axis.
 
     Parameters
@@ -400,6 +400,9 @@ def take(data, indices, axis=None, mode="clip"):
         The axis over which to select values. By default,
         the flattened input array is used.
 
+    batch_dims : int
+        The number of batch dimensions. Defaults to 0.
+
     mode : str, optional
         Specifies how out-of-bound indices will behave [clip, wrap, fast].
         clip: clip to the range (default).
@@ -411,7 +414,7 @@ def take(data, indices, axis=None, mode="clip"):
     Returns
     -------
     ret : relay.Expr
         The computed result.
     """
-    return _make.take(data, indices, axis, mode)
+    return _make.take(data, indices, batch_dims, axis, mode)


def full(fill_value, shape=(), dtype=""):
9 changes: 6 additions & 3 deletions python/tvm/topi/transform.py
@@ -396,7 +396,7 @@ def split(ary, indices_or_sections, axis=0):
     return cpp.split(ary, indices_or_sections, axis)
 
 
-def take(a, indices, axis=None, mode="clip"):
+def take(a, indices, axis=None, batch_dims=0, mode="clip"):
     """Take elements from an array along an axis.
 
     Parameters
@@ -411,6 +411,9 @@ def take(a, indices, axis=None, mode="clip"):
         The axis over which to select values. By default,
         the flattened input array is used.
 
+    batch_dims : int
+        The number of batch dimensions. Defaults to 0.
+
     mode : str, optional
         Specifies how out-of-bound indices will behave.
         clip - clip to the range (default)
@@ -422,8 +425,8 @@ def take(a, indices, axis=None, mode="clip"):
     Returns
     -------
     ret : tvm.te.Tensor
     """
     if axis is None:
-        return cpp.take(a, indices, mode)
-    return cpp.take(a, indices, int(axis), mode)
+        return cpp.take(a, indices, int(batch_dims), mode)
+    return cpp.take(a, indices, int(batch_dims), int(axis), mode)


@tvm.target.generic_func
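
At the TE level the new argument is forwarded to the C++ FFI registered in src/topi/transform.cc below. A minimal sketch, assuming a build with this change:

from tvm import te, topi

a = te.placeholder((2, 3, 4), name="a", dtype="float32")
indices = te.placeholder((2, 5), name="indices", dtype="int32")
out = topi.take(a, indices, axis=1, batch_dims=1)  # out.shape == [2, 5, 4]
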
2 changes: 1 addition & 1 deletion src/relay/op/make_op.h
@@ -104,7 +104,7 @@ Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);
 
 Expr MakeShapeOf(Expr data, DataType dtype);
 
-Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);
+Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode);
 
 }  // namespace relay
 }  // namespace tvm
23 changes: 17 additions & 6 deletions src/relay/op/tensor/transform.cc
@@ -1192,15 +1192,24 @@ bool TakeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   const auto ndim_data = static_cast<int>(data->shape.size());
   const auto ndim_indices = static_cast<int>(indices->shape.size());
   int axis = static_cast<int>(param->axis->value);
+  int batch_dims = static_cast<int>(param->batch_dims->value);
   if (axis < 0) axis += ndim_data;
+  if (batch_dims < 0) batch_dims += ndim_indices;
   ICHECK_LE(axis, ndim_data) << "axis should be within data shape"
                              << ", but got = " << axis;
+  ICHECK_LE(batch_dims, ndim_indices) << "batch_dims should be within indices shape"
+                                      << ", but got = " << batch_dims;
+  ICHECK_LE(batch_dims, axis) << "batch_dims should be less than or equal to axis"
+                              << ", but got = " << batch_dims;
 
-  oshape.reserve(ndim_data - 1 + ndim_indices);
-  for (int i = 0; i < axis; ++i) {
+  oshape.reserve(ndim_data - 1 + ndim_indices - batch_dims);
+  for (int i = 0; i < batch_dims; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  for (int i = batch_dims; i < axis; ++i) {
     oshape.emplace_back(data->shape[i]);
   }
-  for (int i = 0; i < ndim_indices; ++i) {
+  for (int i = batch_dims; i < ndim_indices; ++i) {
     oshape.emplace_back(indices->shape[i]);
   }
   for (int i = axis + 1; i < ndim_data; ++i) {
@@ -1216,14 +1225,16 @@ Array<te::Tensor> TakeCompute(const Attrs& attrs, const Array<te::Tensor>& input
   const auto* param = attrs.as<TakeAttrs>();
   ICHECK(param != nullptr);
   if (!param->axis.defined()) {
-    return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->mode)};
+    return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->batch_dims, param->mode)};
   } else {
-    return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->axis, param->mode)};
+    return Array<te::Tensor>{
+        topi::take(inputs[0], inputs[1], param->batch_dims, param->axis, param->mode)};
   }
 }
 
-Expr MakeTake(Expr data, Expr indices, Integer axis, String mode) {
+Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode) {
   auto attrs = make_object<TakeAttrs>();
+  attrs->batch_dims = std::move(batch_dims);
   attrs->axis = std::move(axis);
   attrs->mode = std::move(mode);
   static const Op& op = Op::Get("take");
15 changes: 9 additions & 6 deletions src/topi/transform.cc
@@ -87,13 +87,16 @@ TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetVal
 });
 
 TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) {
-  if (args.size() == 3) {
-    std::string mode = args[2];
-    *rv = take(args[0], args[1], mode);
-  } else {
-    int axis = args[2];
+  if (args.size() == 4) {
     std::string mode = args[3];
-    *rv = take(args[0], args[1], axis, mode);
+    int batch_dims = args[2];
+    *rv = take(args[0], args[1], batch_dims, mode);
+  } else {
+    ICHECK_EQ(args.size(), 5) << "topi.take expects 4 or 5 arguments";
+    int batch_dims = args[2];
+    int axis = args[3];
+    std::string mode = args[4];
+    *rv = take(args[0], args[1], batch_dims, axis, mode);
   }
 });