Skip to content

Commit

Permalink
fix: Allow full model compilation with collection Inputs
Browse files Browse the repository at this point in the history
- Allow users to specify full model compilation when using
`input_signature`, which allows for complex collection-based inputs
- Enable "psuedo-partitioning" phase for input collections as well as
output collections
- Update `OutputIsCollection` to include dictionary outputs, and add
function `InputIsCollection` to detect collection-based inputs during
graph compilation
- Remove automatic fallback for collection pack/unpack operations when
using `input_signature` argument
- Add collections tests to ensure full compilation is respected for
input and output collections
  • Loading branch information
gs-olive committed Mar 14, 2023
1 parent fce0a01 commit 71ac294
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 59 deletions.
3 changes: 2 additions & 1 deletion core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
// Determine if the block is convertible/has collection output, and based on the result,
// whether full compilation can be expected
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto inputIsCollection = conversion::InputIsCollection(g->block());
auto outputIsCollection = conversion::OutputIsCollection(g->block());
auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection));

// Determine whether user specifications necessitate partitioning
auto isFallbackRequested = userRequestedFallback(cfg);
Expand Down
12 changes: 11 additions & 1 deletion core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,20 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
return convertable_ops;
}

bool InputIsCollection(const torch::jit::Block* b) {
for (auto in : b->inputs()) {
if (in->type()->kind() == torch::jit::TypeKind::TupleType || in->type()->kind() == torch::jit::TypeKind::ListType) {
return true;
}
}
return false;
}

bool OutputIsCollection(const torch::jit::Block* b) {
for (auto out : b->outputs()) {
if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
out->type()->kind() == torch::jit::TypeKind::ListType) {
out->type()->kind() == torch::jit::TypeKind::ListType ||
out->type()->kind() == torch::jit::TypeKind::DictType) {
return true;
}
}
Expand Down
2 changes: 2 additions & 0 deletions core/conversion/conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ std::string ConvertBlockToEngine(

bool OpSupported(const torch::jit::Node* n);

bool InputIsCollection(const torch::jit::Block* b);

bool OutputIsCollection(const torch::jit::Block* b);

bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);
Expand Down
21 changes: 0 additions & 21 deletions cpp/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,27 +74,6 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
LOG_WARNING("Input signature parsing is an experimental feature, behavior and APIs may change");
to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature);
torchtrt::core::CompileSpec internal(converted_input_signature);

TORCHTRT_CHECK(
!external.require_full_compilation,
"Grouped inputs currently requires partial compilation to be enabled, \
this restriction will be relaxed in a future release");

LOG_DEBUG("Grouped inputs currently requires additional settings to enable the feature");
LOG_DEBUG(
"Adding the following ops to torch_executed_ops:" << std::endl
<< " - aten::__getitem__" << std::endl
<< " - prim::ListConstruct" << std::endl
<< " - prim::ListUnpack" << std::endl
<< " - prim::TupleIndex" << std::endl
<< " - prim::TupleConstruct" << std::endl
<< " - prim::TupleUnpack");
external.torch_executed_ops.push_back("aten::__getitem__");
external.torch_executed_ops.push_back("prim::ListConstruct");
external.torch_executed_ops.push_back("prim::ListUnpack");
external.torch_executed_ops.push_back("prim::TupleIndex");
external.torch_executed_ops.push_back("prim::TupleConstruct");
external.torch_executed_ops.push_back("prim::TupleUnpack");
return internal;
}
}
Expand Down
37 changes: 1 addition & 36 deletions py/torch_tensorrt/ts/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,42 +268,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
"Input signature parsing is an experimental feature, behavior and APIs may change",
)
signature = _parse_input_signature(compile_spec["input_signature"])
info.input_signature = _C.InputSignature(signature) # py_object

if not compile_spec["torch_fallback"]["enabled"]:
raise ValueError(
"Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release"
)

log(
Level.Debug,
"Grouped inputs currently requires additional settings to enable the feature",
)
log(
Level.Debug,
"""Adding the following ops to torch_executed_ops:
- aten::__getitem__
- prim::ListConstruct
- prim::ListUnpack
- prim::TupleIndex
- prim::TupleConstruct
- prim::TupleUnpack
""",
)
compile_spec["torch_fallback"]["forced_fallback_ops"].append(
"aten::__getitem__"
)
compile_spec["torch_fallback"]["forced_fallback_ops"].append(
"prim::ListConstruct"
)
compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
compile_spec["torch_fallback"]["forced_fallback_ops"].append(
"prim::TupleConstruct"
)
compile_spec["torch_fallback"]["forced_fallback_ops"].append(
"prim::TupleUnpack"
)
info.input_signature = _C.InputSignature(signature)

else:
raise KeyError(
Expand Down
62 changes: 62 additions & 0 deletions tests/cpp/test_collections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,65 @@ TEST(CppAPITests, TestCollectionComplexModel) {
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
}

TEST(CppAPITests, TestCollectionFullCompilationComplexModel) {
std::string path = "tests/modules/list_input_tuple_output_scripted.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
std::vector<at::Tensor> inputs;
inputs.push_back(in0);

torch::jit::Module mod;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
mod = torch::jit::load(path);
} catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
}
mod.eval();
mod.to(torch::kCUDA);

std::vector<torch::jit::IValue> inputs_;

for (auto in : inputs) {
inputs_.push_back(torch::jit::IValue(in.clone()));
}

std::vector<torch::jit::IValue> complex_inputs;
auto input_list = c10::impl::GenericList(c10::TensorType::get());
input_list.push_back(inputs_[0]);
input_list.push_back(inputs_[0]);

torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);

complex_inputs.push_back(input_list_ivalue);

auto out = mod.forward(complex_inputs);

auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);

auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));

c10::TypePtr elementType = input_shape_ivalue.type();
auto list = c10::impl::GenericList(elementType);
list.push_back(input_shape_ivalue);
list.push_back(input_shape_ivalue);

torch::jit::IValue complex_input_shape(list);
std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
torch::jit::IValue complex_input_shape2(input_tuple2);

auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.min_block_size = 1;
compile_settings.require_full_compilation = true;

// // FP16 execution
compile_settings.enabled_precisions = {torch::kHalf};
// // Compile module
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
auto trt_out = trt_mod.forward(complex_inputs);

ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor()));
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
}
87 changes: 87 additions & 0 deletions tests/py/api/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,34 @@ def test_compile(self):
msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

def test_compile_full_compilation(self):
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = (
torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt")
.eval()
.to("cuda")
)

compile_spec = {
"input_signature": (
(torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),
),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float},
"min_block_size": 1,
"require_full_compilation": True,
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
trt_out = trt_mod((self.input, self.input))
pyt_out = self.model((self.input, self.input))
for (t, p) in zip(trt_out, pyt_out):
cos_sim = cosine_similarity(t, p)
self.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


class TestListInputOutput(unittest.TestCase):
def test_compile(self):
Expand Down Expand Up @@ -225,6 +253,36 @@ def test_compile(self):
msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

def test_compile_full_compilation(self):

self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = (
torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt")
.eval()
.to("cuda")
)

compile_spec = {
"input_signature": (
[torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float},
"min_block_size": 1,
"require_full_compilation": True,
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
trt_out = trt_mod((self.input, self.input))
pyt_out = self.model((self.input, self.input))

for (t, p) in zip(trt_out, pyt_out):
cos_sim = cosine_similarity(t, p)
self.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


class TestListInputTupleOutput(unittest.TestCase):
def test_compile(self):
Expand Down Expand Up @@ -255,6 +313,35 @@ def test_compile(self):
msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

def test_compile_full_compilation(self):

self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = (
torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt")
.eval()
.to("cuda")
)

compile_spec = {
"input_signature": (
[torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float},
"min_block_size": 1,
"require_full_compilation": True,
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
trt_out = trt_mod((self.input, self.input))
pyt_out = self.model((self.input, self.input))
for (t, p) in zip(trt_out, pyt_out):
cos_sim = cosine_similarity(t, p)
self.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 71ac294

Please sign in to comment.