[Relay] Remove overwriting of matmul shapes when they are static (apache#13615)

In the Relay Matmul shape relation, we are a little overenthusiastic about unifying dynamic shapes. If one of the shapes is static, it does not need to be unified. This change rewrites only dynamic shapes to the required static constraints.

* Remove overwriting of matmul shapes when they are static

* Simplify nesting

* Add shape check to dense tests.
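
To illustrate the new behavior, here is a minimal sketch (not part of the commit; the exact printed types are assumptions that may vary across TVM versions). The weight's reduction axis is dynamic (relay.Any()), so type inference constrains it to the static value taken from the data, while the data's static shape is left untouched:

    import tvm
    from tvm import relay

    # Static data shape; dynamic reduction axis on the weight.
    x = relay.var("x", relay.TensorType((4, 2), "float32"))
    w = relay.var("w", relay.TensorType((8, relay.Any()), "float32"))
    y = relay.nn.dense(x, w)

    mod = relay.transform.InferType()(tvm.IRModule.from_expr(y))
    # Expected: x keeps (4, 2); w's Any axis is rewritten to 2.
    print(mod["main"].params[0].checked_type)  # Tensor[(4, 2), float32]
    print(mod["main"].params[1].checked_type)  # Tensor[(8, 2), float32]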
Josh Fromm authored and Mikael Sevenier committed Dec 29, 2022
1 parent 5b028dc commit 4610567
Showing 2 changed files with 24 additions and 12 deletions.
33 changes: 21 additions & 12 deletions src/relay/op/nn/nn.h
@@ -113,23 +113,32 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   std::vector<PrimExpr> B_shape(tensor_b->shape.begin(), tensor_b->shape.end());
   auto sa = A_shape.size();
   auto sb = B_shape.size();
+  size_t index_swap_A;
+  size_t index_swap_B;
   if (transpose_a && transpose_b) {
-    auto tmp = A_shape[sa - 2];
-    A_shape[sa - 2] = B_shape[sb - 1];
-    B_shape[sb - 1] = tmp;
+    index_swap_A = sa - 2;
+    index_swap_B = sb - 1;
   } else if (transpose_a) {
-    auto tmp = A_shape[sa - 2];
-    A_shape[sa - 2] = B_shape[sb - 2];
-    B_shape[sb - 2] = tmp;
+    index_swap_A = sa - 2;
+    index_swap_B = sb - 2;
   } else if (transpose_b) {
-    auto tmp = A_shape[sa - 1];
-    A_shape[sa - 1] = B_shape[sb - 1];
-    B_shape[sb - 1] = tmp;
+    index_swap_A = sa - 1;
+    index_swap_B = sb - 1;
   } else {
-    auto tmp = A_shape[sa - 1];
-    A_shape[sa - 1] = B_shape[sb - 2];
-    B_shape[sb - 2] = tmp;
+    index_swap_A = sa - 1;
+    index_swap_B = sb - 2;
   }
+
+  // Rewrite dynamic axes to static where constraints allow.
+  auto tmp = A_shape[index_swap_A];
+  if (A_shape[index_swap_A].as<tir::AnyNode>()) {
+    A_shape[index_swap_A] = B_shape[index_swap_B];
+  }
+  if (B_shape[index_swap_B].as<tir::AnyNode>()) {
+    B_shape[index_swap_B] = tmp;
+  }
+
+  // Update input types with new constrained shapes.
   reporter->Assign(types[0], TensorType(A_shape, tensor_a->dtype));
   reporter->Assign(types[1], TensorType(B_shape, tensor_b_dtype));
 }
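
The two tir::AnyNode guards are the heart of the fix: an axis is overwritten only when it is dynamic, so a static dimension can no longer be clobbered by a dynamic counterpart. A plain-Python paraphrase of the rule (purely illustrative, with None standing in for tir::Any):

    def unify_axes(a_dim, b_dim):
        # Only a dynamic (None) axis is rewritten; a static axis is preserved.
        new_a = b_dim if a_dim is None else a_dim
        new_b = a_dim if b_dim is None else b_dim  # uses the original a_dim, like tmp
        return new_a, new_b

    assert unify_axes(2, None) == (2, 2)  # dynamic B axis constrained to static 2
    assert unify_axes(None, 3) == (3, 3)  # dynamic A axis constrained to static 3
    assert unify_axes(4, 4) == (4, 4)     # both static: nothing is rewritten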
3 changes: 3 additions & 0 deletions tests/python/relay/test_op_level1.py
@@ -25,6 +25,7 @@
 import tvm.topi.testing
 from tvm.contrib.nvcc import have_fp16
 import tvm.testing
+from tvm.topi.utils import get_const_tuple
 
 executor_kind = tvm.testing.parameter("graph", "vm")
 
@@ -695,6 +696,8 @@ def test_dense(executor_kind):
     w = relay.var("w", relay.TensorType((k, n), dtype))
     y = relay.nn.dense(x, w)
     yy = run_infer_type(y)
+    # Confirm that input shape has not been rewritten to become dynamic.
+    assert get_const_tuple(yy.type_args[0].shape) == (4, 2)
 
     n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
     x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
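
get_const_tuple is what makes the new assertion meaningful: it converts a shape of constant IntImm nodes into a plain Python tuple, so the comparison fails if any axis has been rewritten to a non-constant expression. A small usage sketch (hedged: behavior for non-constant axes varies across TVM versions):

    import tvm
    from tvm.topi.utils import get_const_tuple

    shape = [tvm.tir.IntImm("int32", 4), tvm.tir.IntImm("int32", 2)]
    assert get_const_tuple(shape) == (4, 2)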
