[Relay] Enforce static dim for non-concat axis if one or more tensors have static dim #7487

Merged: 9 commits, Feb 26, 2021

Member

@masahi masahi commented Feb 22, 2021

Currently, the concat type relation assigns the output shape Any() to a non-concat axis if even one of the corresponding input dims is Any. But for non-concat axes, if one or more input tensors have a static dim, then every input dim and the output dim must be that same static value.

For example, after this PR, the typing for the concat op with concat axis == 0 changes as follows:

  • Current: [(?, 3), (?, ?)] -> (?, ?)
  • After this PR: [(?, 3), (?, ?)] -> (?, 3)
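
The typing rule described above can be sketched in standalone C++. This is only an illustrative model, not TVM's actual concat type relation: `Dim`, `MergeNonConcatDim`, and `InferConcatShape` are hypothetical names, and `std::nullopt` stands in for Relay's `Any`.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// A dimension is either a known extent or unknown.
// std::nullopt is a stand-in for Relay's Any (assumption of this sketch).
using Dim = std::optional<int64_t>;
using Shape = std::vector<Dim>;

// Merge the extents that all inputs report for one non-concat axis:
// any static extent wins, and two different static extents are an error.
Dim MergeNonConcatDim(const std::vector<Dim>& dims) {
  Dim result = std::nullopt;
  for (const Dim& d : dims) {
    if (!d.has_value()) continue;  // unknown dims add no constraint
    if (result.has_value() && *result != *d) {
      throw std::invalid_argument("static dims disagree on a non-concat axis");
    }
    result = d;
  }
  return result;  // still unknown only if every input dim was unknown
}

// Infer the output shape of a concat over `axis` from the input shapes.
Shape InferConcatShape(const std::vector<Shape>& inputs, std::size_t axis) {
  if (inputs.empty()) throw std::invalid_argument("need at least one input");
  const std::size_t rank = inputs[0].size();
  if (axis >= rank) throw std::invalid_argument("axis out of range");
  for (const Shape& s : inputs) {
    if (s.size() != rank) throw std::invalid_argument("rank mismatch");
  }
  Shape out(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    if (i == axis) {
      // Concat axis: the sum is known only if every input extent is known.
      int64_t total = 0;
      bool all_static = true;
      for (const Shape& s : inputs) {
        if (s[i].has_value()) {
          total += *s[i];
        } else {
          all_static = false;
        }
      }
      if (all_static) out[i] = total;
    } else {
      std::vector<Dim> column;
      for (const Shape& s : inputs) column.push_back(s[i]);
      out[i] = MergeNonConcatDim(column);
    }
  }
  return out;
}
```

Under these assumptions, calling `InferConcatShape` on the shapes `(?, 3)` and `(?, ?)` with `axis` 0 yields `(?, 3)`, which matches the "After this PR" line.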

This, together with #7479, removes all unnecessary any_dim from PyTorch MaskRCNN and significantly simplifies dynamic injective kernels. Below is an example of a super-complicated and slow dynamic injective op, caused by too many any_dim. This had been the bottleneck in PyTorch MaskRCNN, but after this PR it becomes reasonable and is no longer a bottleneck.

Thanks to this optimization, PyTorch MaskRCNN now runs 30 milliseconds (!!) faster.

Please review @kevinthesun @anijain2305 @mbrookhart @jwfromm

@masahi masahi force-pushed the concat-rel-more-static branch from 56d53b2 to 047a0b0 Compare February 22, 2021 09:39
@masahi masahi marked this pull request as ready for review February 22, 2021 11:13
Contributor

@anijain2305 anijain2305 left a comment

Nice addition. I think we might have similar scenarios in TF models as well @trevor-m

LGTM

Member Author

masahi commented Feb 24, 2021

@mbrookhart Can you take a look as well?

Contributor

mbrookhart commented Feb 24, 2021

My main concern with this, and please correct me if I'm wrong, is concatenating something like shape [1, 3] and shape [1, ?] on axis 0. In that case, I believe this will return [2, 3], and we don't do any runtime checks on the second argument. If the model is wrong, and, say, the input comes in as [1, 2] instead of [1, 3], we'll end up with a runtime segfault in generated code, which is very difficult to debug.
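
To make the concern concrete, here is a minimal sketch of the arithmetic involved. It assumes that the generated kernel indexes its input using the statically assumed shape, which is a simplification made for illustration; `NumElements` and `ElementsReadOutOfBounds` are hypothetical helpers, not TVM functions.

```cpp
#include <cstdint>
#include <vector>

// Number of elements in a fully static shape.
int64_t NumElements(const std::vector<int64_t>& shape) {
  int64_t n = 1;
  for (int64_t d : shape) n *= d;
  return n;
}

// How many elements a kernel compiled for `assumed` would read past the end
// of a buffer that actually holds a tensor of shape `actual`.
// Assumption: the kernel touches exactly NumElements(assumed) input elements.
int64_t ElementsReadOutOfBounds(const std::vector<int64_t>& assumed,
                                const std::vector<int64_t>& actual) {
  const int64_t over = NumElements(assumed) - NumElements(actual);
  return over > 0 ? over : 0;
}
```

With the shapes from the comment, a kernel compiled for `[1, 3]` that receives `[1, 2]` would read one element past the end of the buffer under this model, and nothing checks for that at runtime.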

Member Author

masahi commented Feb 24, 2021

@mbrookhart I'm not sure about this, but doesn't the concat shape function do a runtime check that prevents the segfault? I mean, whether all dims are dynamic or the dims are mixed static/dynamic shouldn't change what the shape func has to do, should it?

@mbrookhart
Contributor

The shape function only runs if the output shape of the op is dynamic at compile time:

} else if (IsDynamic(ret_type)) {
  Function func = Downcast<Function>(cn->op);
  return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type);

auto out_shapes = EmitShapeFunc(scope, func, new_args);

// Insert the shape function given a primitive function.
Array<Expr> EmitShapeFunc(LetList* scope, const Function& func,
                          const std::vector<Expr>& new_args) {
  Array<Expr> shape_func_ins;
  auto engine = CompileEngine::Global();
  CCacheKey key(func, target_host_);
  auto cfunc = engine->LowerShapeFunc(key);
  auto input_states = cfunc->shape_func_param_states;
  Array<Integer> is_inputs;
  int input_pos = 0;
  TVMContext cpu_ctx = default_context_;
  CHECK_EQ(new_args.size(), input_states.size());
  for (size_t i = 0; i < new_args.size(); ++i) {
    Expr arg = new_args[i];
    Type ty;
    if (const auto* vn = arg.as<VarNode>()) {
      ty = vn->type_annotation;
    } else {
      ty = arg->checked_type();
    }
    int state = input_states[i]->value;
    // Pass Shapes
    if (state == 2) {
      std::vector<Expr> exprs = FromTupleType(ty, arg);
      for (size_t j = 0; j < exprs.size(); ++j) {
        Expr sh_of = ExprMutator::Mutate(ShapeOf(exprs[j]));
        Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr));
        shape_func_ins.push_back(scope->Push(in_shape_var, sh_of));
        input_pos++;
      }
      is_inputs.push_back(0);
    } else if (state == 1) {
      auto new_arg = ExprMutator::Mutate(arg);
      auto ctx = GetContext(arg);
      if (ctx.device_type != cpu_ctx.device_type) {
        new_arg = DeviceCopy(new_arg, ctx.device_type, cpu_ctx.device_type);
      }
      Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr));
      shape_func_ins.push_back(scope->Push(in_shape_var, new_arg));
      input_pos++;
      is_inputs.push_back(1);
    } else {
      // TODO(@jroesch): handle 3rd case
      LOG(FATAL) << "unsupported shape function input state";
    }
  }
  Array<Expr> out_shapes;
  for (size_t i = 0; i < cfunc->outputs.size(); ++i) {
    auto out = cfunc->outputs[i];
    auto tt = TensorType(out->shape, out->dtype);
    // Put shape func on CPU. This also ensures that everything between
    // shape_of and shape_func are on CPU.
    auto alloc = MakeStaticAllocation(scope, tt, cpu_ctx, std::to_string(i));
    Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr));
    alloc = scope->Push(shape_func_out_var, alloc);
    out_shapes.push_back(alloc);
  }
  auto shape_call = ShapeFunc(func, Tuple(shape_func_ins), Tuple(out_shapes), is_inputs);
  Var shape_func_var("shape_func", Type(nullptr));
  scope->Push(shape_func_var, shape_call);
  return out_shapes;
}

Contributor

mbrookhart commented Feb 24, 2021

If you have dynamically shaped inputs and a statically shaped output, the shape func won't run, so you won't be able to update/check your assumption on that dynamic input at runtime.
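
The dispatch rule described here can be modeled in a few lines. This is a simplified sketch rather than TVM code: `IsDynamicShape`, `Invoke`, and `ChooseInvocation` are hypothetical stand-ins for the `IsDynamic(ret_type)` branch quoted earlier, and `std::nullopt` again represents an unknown dim.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// std::nullopt models an unknown (Any) dimension; an assumption of this sketch.
using Dim = std::optional<int64_t>;
using Shape = std::vector<Dim>;

// A shape is dynamic if at least one of its dims is unknown.
bool IsDynamicShape(const Shape& shape) {
  for (const Dim& d : shape) {
    if (!d.has_value()) return true;
  }
  return false;
}

enum class Invoke { kStatic, kDynamic };

// Only the OUTPUT shape decides whether the dynamic path, and therefore the
// shape function, is used. The input shapes are deliberately not consulted.
Invoke ChooseInvocation(const Shape& out_shape) {
  return IsDynamicShape(out_shape) ? Invoke::kDynamic : Invoke::kStatic;
}
```

In this model, an output typed as `[2, 3]` always takes the static path, even when one of the inputs was typed as `[1, ?]`, so nothing validates the unknown input dim at runtime.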

Member Author

masahi commented Feb 24, 2021

Thanks, that makes sense. I wish there were a way to decouple the output shape calculation and the input shape check from the shape func, so that we could add an input shape check for cases like this (dynamic input shape + static output shape).
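
A decoupled input check of the kind wished for here might look roughly like the sketch below. It is purely hypothetical: nothing in this conversation indicates that TVM provides such a hook, and `FindShapeMismatch` is an invented name. It compares runtime input shapes against the static output dims on non-concat axes and returns a readable message instead of letting a kernel read out of bounds.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// std::nullopt models an unknown (Any) dimension; an assumption of this sketch.
using Dim = std::optional<int64_t>;

// Returns an error message if a runtime input contradicts a static output dim
// on a non-concat axis, or std::nullopt if all inputs are consistent.
std::optional<std::string> FindShapeMismatch(
    const std::vector<Dim>& out_shape,
    const std::vector<std::vector<int64_t>>& runtime_inputs,
    std::size_t axis) {
  for (std::size_t n = 0; n < runtime_inputs.size(); ++n) {
    const std::vector<int64_t>& in = runtime_inputs[n];
    if (in.size() != out_shape.size()) {
      return "concat input " + std::to_string(n) + " has the wrong rank";
    }
    for (std::size_t i = 0; i < in.size(); ++i) {
      // The concat axis and unknown output dims impose no constraint here.
      if (i == axis || !out_shape[i].has_value()) continue;
      if (in[i] != *out_shape[i]) {
        return "concat input " + std::to_string(n) + " has extent " +
               std::to_string(in[i]) + " on axis " + std::to_string(i) +
               ", expected " + std::to_string(*out_shape[i]);
      }
    }
  }
  return std::nullopt;
}
```

For the case raised in the review, an output typed as `[2, 3]` with runtime inputs `[1, 3]` and `[1, 2]` on axis 0 would produce an error for the second input rather than a segfault.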

@masahi masahi marked this pull request as draft February 24, 2021 21:44
@masahi masahi force-pushed the concat-rel-more-static branch from ae3dce3 to 2d99cd6 Compare February 25, 2021 12:09
@masahi masahi force-pushed the concat-rel-more-static branch from 2d99cd6 to 5d66703 Compare February 25, 2021 19:10
Contributor

@mbrookhart mbrookhart left a comment

LGTM

@masahi masahi marked this pull request as ready for review February 25, 2021 21:08
@masahi masahi merged commit 63ea8e1 into apache:main Feb 26, 2021
Member Author

masahi commented Feb 26, 2021

Thanks @anijain2305 @mbrookhart

Lokiiiiii pushed a commit to Lokiiiiii/tvm that referenced this pull request Mar 2, 2021
… have static dim (apache#7487)

* enforce static dim for non-concat axis

* assign any when all dims are dyn

* add missing case

* simplify

* add test

* only enforce static dim constraint if concat output is dynamic

* more update to concat type rel

* update tests

* fixed compile warning
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Mar 2, 2021
… have static dim (apache#7487)

* enforce static dim for non-concat axis

* assign any when all dims are dyn

* add missing case

* simplify

* add test

* only enforce static dim constraint if concat output is dynamic

* more update to concat type rel

* update tests

* fixed compile warning