Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC][TensorRT] Fixes for explicit batch mode, Support reduce to scalar, Support split op #7967

Merged
merged 3 commits into from
May 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions python/tvm/relay/op/contrib/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _func_wrapper(expr):

def reduce_annotate_fn(attrs, args, op_name):
"""Helper for reduce operations."""
if not attrs.axis or len(attrs.axis) == 0:
if get_tensorrt_use_implicit_batch_mode() and (not attrs.axis or len(attrs.axis) == 0):
logger.info("%s: cannot reduce to scalar.", op_name)
return False
if attrs.exclude:
Expand Down Expand Up @@ -317,17 +317,18 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable
for arg in args
]

# RelayVM + TRT doesn't support scalar addition yet.
for shape in shapes:
if len(shape) < 1:
return False
# Scalars require explicit batch mode.
if get_tensorrt_use_implicit_batch_mode() and any([len(shape) < 1 for shape in shapes]):
return False

if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if (
not get_tensorrt_use_implicit_batch_mode()
and (isinstance(args[0], Constant) or isinstance(args[1], Constant))
and len(shapes[0]) > 0
and len(shapes[1]) > 0
and shapes[0][0] == shapes[1][0]
and shapes[0][0] != 1
and (len(shapes[0]) > 3 or len(shapes[1]) > 3)
Expand Down Expand Up @@ -552,6 +553,19 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
return True


@_register_external_dynamic_check_func("split")
def split_annotate_fn(expr):
    """Check if split is supported by TensorRT."""
    # TensorRT only handles float32 tensors in this integration.
    if not all(arg.checked_type.dtype == "float32" for arg in expr.args):
        logger.info("Only float32 inputs are supported for TensorRT.")
        return False
    # In implicit batch mode the batch (outermost) dimension is fixed by the
    # engine and cannot be split.
    splits_batch_dim = int(expr.attrs.axis) == 0
    if splits_batch_dim and get_tensorrt_use_implicit_batch_mode():
        logger.info("split: can't modify batch dimension.")
        return False
    return True


@_register_external_dynamic_check_func("nn.conv2d_transpose")
def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv2d_transpose is supported by TensorRT."""
Expand Down Expand Up @@ -870,6 +884,11 @@ def visit_call(self, call):
"nn.conv3d_transpose",
"nn.dense",
"nn.batch_matmul",
"sum",
"prod",
"max",
"min",
"mean",
]
)
if isinstance(call.op, tvm.tir.op.Op):
Expand Down Expand Up @@ -968,6 +987,7 @@ def visit_call(self, call):
# Create new pruned module
new_mod = tvm.IRModule(mod.functions, mod.type_definitions)
new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
new_mod = transform.RemoveUnusedFunctions()(new_mod)
return new_mod


Expand Down
31 changes: 31 additions & 0 deletions src/relay/backend/contrib/tensorrt/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
SetPadNodeAttribute(node, cn);
} else if (name == "strided_slice") {
SetStridedSliceNodeAttribute(node, cn);
} else if (name == "split") {
SetSplitNodeAttribute(node, cn);
} else {
SetCallNodeAttribute(node, cn);
}
Expand Down Expand Up @@ -172,6 +174,35 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
node->SetAttr("strides", strides_attr);
}

void SetSplitNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
const auto* split_attr = cn->attrs.as<SplitAttrs>();
ICHECK(split_attr);

std::vector<std::string> indices_or_sections;
std::vector<std::string> mode;
std::vector<std::string> axis = {std::to_string(split_attr->axis)};
if (const IntImmNode* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
mode.emplace_back("sections");
indices_or_sections.emplace_back(std::to_string(sections->value));
} else {
mode.emplace_back("indices");
auto indices = Downcast<tvm::Array<Integer>>(split_attr->indices_or_sections);
for (const auto& i : indices) {
indices_or_sections.emplace_back(std::to_string(i->value));
}
}

std::vector<dmlc::any> indices_or_sections_attr;
std::vector<dmlc::any> mode_attr;
std::vector<dmlc::any> axis_attr;
indices_or_sections_attr.emplace_back(indices_or_sections);
mode_attr.emplace_back(mode);
axis_attr.emplace_back(axis);
node->SetAttr("indices_or_sections", indices_or_sections_attr);
node->SetAttr("mode", mode_attr);
node->SetAttr("axis", axis_attr);
}

void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
auto ctx = transform::PassContext::Current();
auto cfg = ctx->GetConfig<TensorRTCompilerConfig>("relay.ext.tensorrt.options");
Expand Down
62 changes: 59 additions & 3 deletions src/runtime/contrib/tensorrt/tensorrt_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,53 @@ class ConcatOpConverter : public TensorRTOpConverter {
}
};

class SplitOpConverter : public TensorRTOpConverter {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add #if TRT_VERSION_GE(5, 1, 5) before it because the code uses params->network->addSlice.

public:
SplitOpConverter() : TensorRTOpConverter({kTensor}) {}

void Convert(TensorRTOpConverterParams* params) const {
auto input = params->inputs.at(0).tensor;
auto input_dims = TrtDimsToVector(input->getDimensions());
const int original_axis = std::stoi(params->node.GetAttr<std::vector<std::string>>("axis")[0]);
const int axis = ConvertAxis(params, original_axis, input_dims.size());
auto indices_or_sections =
params->node.GetAttr<std::vector<std::string>>("indices_or_sections");
auto mode = params->node.GetAttr<std::vector<std::string>>("mode")[0];

std::vector<int> split_starts;
std::vector<int> split_sizes;
if (mode == "sections") {
int sections = std::stoi(indices_or_sections[0]);
int size = input_dims[axis] / sections;
for (int i = 0; i < sections; i++) {
split_starts.push_back(i * size);
split_sizes.push_back(size);
}
} else {
int last_index = 0;
for (size_t i = 0; i < indices_or_sections.size(); ++i) {
int index = std::stoi(indices_or_sections[i]);
split_starts.push_back(last_index);
split_sizes.push_back(index - last_index);
last_index = index;
}
split_starts.push_back(last_index);
split_sizes.push_back(input_dims[axis] - last_index);
}

std::vector<int> start(input_dims.size(), 0);
std::vector<int> size(input_dims.begin(), input_dims.end());
std::vector<int> strides(input_dims.size(), 1);
for (int i = 0; i < split_sizes.size(); ++i) {
start[axis] = split_starts[i];
size[axis] = split_sizes[i];
auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start),
VectorToTrtDims(size), VectorToTrtDims(strides));
params->outputs.push_back(slice_layer->getOutput(0));
}
}
};

class BiasAddOpConverter : public TensorRTOpConverter {
public:
BiasAddOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
Expand Down Expand Up @@ -970,9 +1017,17 @@ class ReduceOpConverter : public TensorRTOpConverter {
// TODO(trevmorr): Support reduce to scalar.
ICHECK_GT(str_axis.size(), 0);
uint32_t reduce_axes = 0;
for (size_t i = 0; i < str_axis.size(); ++i) {
const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims);
reduce_axes |= 1 << axis;

if (str_axis.size() == 1 && str_axis[0].length() == 0) {
// Reduce to scalar
for (int i = 0; i < input->getDimensions().nbDims; ++i) {
reduce_axes |= 1 << i;
}
} else {
for (size_t i = 0; i < str_axis.size(); ++i) {
const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims);
reduce_axes |= 1 << axis;
}
}
auto reduce_layer = params->network->addReduce(*input, it->second, reduce_axes, keepdims);
params->outputs.push_back(reduce_layer->getOutput(0));
Expand Down Expand Up @@ -1072,6 +1127,7 @@ GetOpConverters() {
map->emplace("expand_dims", std::make_shared<ExpandDimsOpConverter>());
map->emplace("squeeze", std::make_shared<SqueezeOpConverter>());
map->emplace("concatenate", std::make_shared<ConcatOpConverter>());
map->emplace("split", std::make_shared<SplitOpConverter>());
Copy link
Contributor

@apivovarov apivovarov Apr 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should go inside #if TRT_VERSION_GE(5, 1, 5) block below

map->emplace("nn.conv2d_transpose", std::make_shared<Conv2DTransposeOpConverter>());
map->emplace("transpose", std::make_shared<TransposeOpConverter>());
map->emplace("layout_transform", std::make_shared<LayoutTransformOpConverter>());
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/contrib/tensorrt/tensorrt_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ class TensorRTRuntime : public JSONRuntimeBase {
* do nothing.
*/
void BuildEngine() {
batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
batch_size_ =
data_entry_[input_var_eid_[0]]->ndim == 0 ? 1 : data_entry_[input_var_eid_[0]]->shape[0];
if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return;
DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_
<< " with batch size " << batch_size_;
Expand Down
13 changes: 13 additions & 0 deletions tests/python/contrib/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,19 @@ def get_graph(input_shapes, axis):
run_and_verify_func(get_graph([(1, 2, 6, 6), (1, 3, 6, 6)], axis=1))


def test_split():
    """Exercise relay.split through the TensorRT BYOC path."""

    def make_split_graph(shape, split_spec, split_axis):
        # Single-op graph: split the input and return every piece as a tuple.
        inp = relay.var("x", shape=shape, dtype="float32")
        pieces = relay.split(inp, indices_or_sections=split_spec, axis=split_axis)
        func = relay.Function([inp], pieces.astuple())
        return func, {"x": shape}, []

    # Cover both "sections" (int) and "indices" (list) forms of the attribute.
    for split_spec in (2, 4, [8], [2, 3, 6, 10, 14]):
        run_and_verify_func(make_split_graph((1, 16), split_spec, 1))


def test_conv2d_transpose():
def get_graph(
x_shape=(1, 32, 8, 8),
Expand Down