[Relay] Enforce static dim for non-concat axis if one or more tensors have static dim (apache#7487)

* enforce static dim for non-concat axis

* assign any when all dims are dyn

* add missing case

* simplify

* add test

* only enforce static dim constraint if concat output is dynamic

* more update to concat type rel

* update tests

* fixed compile warning
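In effect (a minimal sketch, not part of the diffs below; assumes a TVM build that includes this patch), a single static dim on a non-concat axis now pins that axis of the inferred output type when the concat axis is dynamic:

import tvm
from tvm import relay

# Concatenate a (?, 3) tensor with a (?, ?) tensor along axis 0.
xs = [
    relay.var("x0", shape=(relay.Any(), 3), dtype="float32"),
    relay.var("x1", shape=(relay.Any(), relay.Any()), dtype="float32"),
]
mod = tvm.IRModule()
mod["main"] = relay.Function(xs, relay.op.concatenate(xs, axis=0))
mod = relay.transform.InferType()(mod)
# With this change: Tensor[(?, 3), float32]; previously: Tensor[(?, ?), float32].
print(mod["main"].body.checked_type)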
masahi authored and trevor-m committed Mar 2, 2021
1 parent e67955f commit 407e4e6
Showing 2 changed files with 73 additions and 17 deletions.
69 changes: 52 additions & 17 deletions src/relay/op/tensor/transform.h
@@ -101,29 +101,64 @@ bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   }
 
   // Calculate shape
-  std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
-  int data_length = static_cast<int>(tensor_tuple->fields.size());
+  std::vector<IndexExpr> oshape(ndim);
+  const size_t data_length = tensor_tuple->fields.size();
+
+  // Accumulate the concat axis output dim or decide if this is dynamic concat
+  bool is_dynamic_concat = false;
+  std::vector<TensorType> input_tensors;
+  IndexExpr concat_output_dim = first->shape[axis];
+  for (size_t i = 0; i < data_length; ++i) {
+    const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
+    input_tensors.push_back(e);
+    if (e->shape[axis].as<AnyNode>()) {
+      is_dynamic_concat = true;
+      concat_output_dim = Any();
+    } else if (i > 0 && !is_dynamic_concat) {
+      // accumulate axis dimension
+      concat_output_dim += e->shape[axis];
+    }
+  }
+
+  oshape[axis] = concat_output_dim;
+
   for (int i = 0; i < ndim; ++i) {
+    if (i == axis) {
+      // The concat axis is already handled above.
+      // The rest of the body sets the output shape for non-concat axes
+      continue;
+    }
     std::vector<IndexExpr> non_any;
-    for (int j = 0; j < data_length; ++j) {
-      const auto& e = Downcast<TensorType>(tensor_tuple->fields[j]);
+    for (size_t j = 0; j < data_length; ++j) {
+      const auto& e = input_tensors[j];
       if (!e->shape[i].as<AnyNode>()) {
         non_any.push_back(e->shape[i]);
-        // accumulate axis dimension
-        if (j > 0 && i == axis && !oshape[i].as<AnyNode>()) {
-          oshape[i] += e->shape[i];
-        }
       }
     }
-    int non_any_size = static_cast<int>(non_any.size());
-    if (non_any_size != data_length) oshape[i] = Any();
-    if (i != axis) {
-      for (int k = 1; k < non_any_size; k++) {
-        if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
-        throw Error(
-            "relay.concatenate requires all tensors have the same shape "
-            "on non-concatenating axes");
-      }
+    size_t non_any_size = non_any.size();
+    for (size_t k = 1; k < non_any_size; k++) {
+      if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
+      throw Error(
+          "relay.concatenate requires all tensors have the same shape "
+          "on non-concatenating axes");
     }
+
+    if (non_any_size == data_length) {
+      // All static case
+      oshape[i] = non_any[0];
+    } else if (non_any_size > 0 && is_dynamic_concat) {
+      // For non-concat axes, we want to enforce static shape constraint.
+      // However, if the concat axis is static, the output shape would become static while
+      // the input could be partially static/dynamic. To prevent runtime segfaults due to the lack
+      // of runtime input shape checking for such cases, static shape constraint is only enforced
+      // when the output concat axis is dynamic.
+      //
+      // Examples (both concat on the first axis):
+      //  * [(?, 3), (?, ?)] -> (?, 3)
+      //  * [(1, 3), (1, ?)] -> (2, ?)
+      oshape[i] = non_any[0];
+    } else {
+      oshape[i] = Any();
+    }
   }
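The AssertEQ/throw path above is pre-existing behavior that the rewrite preserves (it now runs for every non-concat axis, since the concat axis is skipped via continue). A hedged sketch of how it surfaces from Python when static dims mismatch on a non-concat axis:

import tvm
from tvm import relay

a = relay.var("a", shape=(2, 3), dtype="float32")
b = relay.var("b", shape=(2, 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([a, b], relay.op.concatenate([a, b], axis=0)))
try:
    relay.transform.InferType()(mod)
except Exception as err:  # exact exception class depends on the TVM version
    # Expect: "relay.concatenate requires all tensors have the same shape
    # on non-concatenating axes"
    print(err)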

21 changes: 21 additions & 0 deletions tests/python/relay/test_any.py
@@ -208,6 +208,27 @@ def test_any_concat():
     ref = np.concatenate(x_np, axis=0)
     check_result(x_np, mod, ref)
 
+    def test_oshape(in_vars, axis, oshape):
+        z = relay.op.concatenate(in_vars, axis=axis)
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function(in_vars, z)
+        typed_mod = relay.transform.InferType()(mod)
+        assert typed_mod["main"].body.checked_type == relay.TensorType(oshape, dtype="float32")
+
+    x = [relay.var("x", shape=(relay.Any(), 3), dtype="float32") for _ in range(3)]
+    x.append(relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32"))
+
+    test_oshape(x, 0, (relay.Any(), 3))
+    test_oshape(x, 1, (relay.Any(), relay.Any()))
+
+    # [(1, 3), (1, ?)] -> (2, ?)
+    x = [
+        relay.var("x", shape=(1, 3), dtype="float32"),
+        relay.var("x", shape=(1, relay.Any()), dtype="float32"),
+    ]
+    test_oshape(x, 0, (2, relay.Any()))
+    test_oshape(x, 1, (1, relay.Any()))
+
 
 def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
     x = relay.var("x", shape=x_shape, dtype="float32")
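For completeness, a hedged sketch of running the second mixed case end to end on the Relay VM (variable names mirror the test above; executor calls assume the TVM Python API of this era):

import numpy as np
import tvm
from tvm import relay

# [(1, 3), (1, ?)] -> (2, ?): the type is partially dynamic, but concrete
# inputs of shape (1, 3) each still yield a (2, 3) result at runtime.
xs = [
    relay.var("x0", shape=(1, 3), dtype="float32"),
    relay.var("x1", shape=(1, relay.Any()), dtype="float32"),
]
mod = tvm.IRModule()
mod["main"] = relay.Function(xs, relay.op.concatenate(xs, axis=0))

x_np = [np.random.uniform(size=(1, 3)).astype("float32") for _ in range(2)]
ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
res = ex.evaluate()(*x_np)
np.testing.assert_allclose(res.asnumpy(), np.concatenate(x_np, axis=0))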
