diff --git a/core/compiler.cpp b/core/compiler.cpp
index 3dd735a59e..bda4583664 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -352,8 +352,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Determine if the block is convertible/has collection output, and based on the result,
   // whether full compilation can be expected
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto inputIsCollection = conversion::InputIsCollection(g->block());
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+  auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection));
 
   // Determine whether user specifications necessitate partitioning
   auto isFallbackRequested = userRequestedFallback(cfg);
diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index 940e178850..b0e8174500 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -556,10 +556,20 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
   return convertable_ops;
 }
 
+bool InputIsCollection(const torch::jit::Block* b) {
+  for (auto in : b->inputs()) {
+    if (in->type()->kind() == torch::jit::TypeKind::TupleType || in->type()->kind() == torch::jit::TypeKind::ListType) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool OutputIsCollection(const torch::jit::Block* b) {
   for (auto out : b->outputs()) {
     if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
-        out->type()->kind() == torch::jit::TypeKind::ListType) {
+        out->type()->kind() == torch::jit::TypeKind::ListType ||
+        out->type()->kind() == torch::jit::TypeKind::DictType) {
       return true;
     }
   }
diff --git a/core/conversion/conversion.h b/core/conversion/conversion.h
index a578c4288e..4ef092a1be 100644
--- a/core/conversion/conversion.h
+++ b/core/conversion/conversion.h
@@ -26,6 +26,8 @@ std::string ConvertBlockToEngine(
 
 bool OpSupported(const torch::jit::Node* n);
 
+bool InputIsCollection(const torch::jit::Block* b);
+
 bool OutputIsCollection(const torch::jit::Block* b);
 
 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);
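For context, InputIsCollection and OutputIsCollection only inspect the top-level TypeKind of a block's inputs and outputs. A minimal sketch of their behavior on a hand-written graph, illustrative only and not part of this diff; the namespace and include paths are assumptions based on the repository layout:

// Illustrative sketch: exercising the new detection helpers on a parsed
// TorchScript graph with a tuple input and a plain tensor output.
#include <memory>
#include <string>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

#include "core/conversion/conversion.h"

void collection_detection_demo() {
  const std::string graph_ir = R"IR(
    graph(%xs : (Tensor, Tensor)):
      %x : Tensor, %y : Tensor = prim::TupleUnpack(%xs)
      %one : int = prim::Constant[value=1]()
      %out : Tensor = aten::add(%x, %y, %one)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, g.get());

  // TupleType input -> true; plain Tensor output -> false.
  bool in_is_coll = torch_tensorrt::core::conversion::InputIsCollection(g->block());
  bool out_is_coll = torch_tensorrt::core::conversion::OutputIsCollection(g->block());
  (void)in_is_coll;
  (void)out_is_coll;
}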
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index 0fe56265e7..1954827893 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -74,27 +74,6 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
     LOG_WARNING("Input signature parsing is an experimental feature, behavior and APIs may change");
     to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature);
     torchtrt::core::CompileSpec internal(converted_input_signature);
-
-    TORCHTRT_CHECK(
-        !external.require_full_compilation,
-        "Grouped inputs currently requires partial compilation to be enabled, \
-        this restriction will be relaxed in a future release");
-
-    LOG_DEBUG("Grouped inputs currently requires additional settings to enable the feature");
-    LOG_DEBUG(
-        "Adding the following ops to torch_executed_ops:" << std::endl
-        << " - aten::__getitem__" << std::endl
-        << " - prim::ListConstruct" << std::endl
-        << " - prim::ListUnpack" << std::endl
-        << " - prim::TupleIndex" << std::endl
-        << " - prim::TupleConstruct" << std::endl
-        << " - prim::TupleUnpack");
-    external.torch_executed_ops.push_back("aten::__getitem__");
-    external.torch_executed_ops.push_back("prim::ListConstruct");
-    external.torch_executed_ops.push_back("prim::ListUnpack");
-    external.torch_executed_ops.push_back("prim::TupleIndex");
-    external.torch_executed_ops.push_back("prim::TupleConstruct");
-    external.torch_executed_ops.push_back("prim::TupleUnpack");
     return internal;
   }
 }
diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py
index 0e11d3bcd3..8f06e2ef71 100644
--- a/py/torch_tensorrt/ts/_compile_spec.py
+++ b/py/torch_tensorrt/ts/_compile_spec.py
@@ -268,42 +268,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
             "Input signature parsing is an experimental feature, behavior and APIs may change",
         )
         signature = _parse_input_signature(compile_spec["input_signature"])
-        info.input_signature = _C.InputSignature(signature)  # py_object
-
-        if not compile_spec["torch_fallback"]["enabled"]:
-            raise ValueError(
-                "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release"
-            )
-
-        log(
-            Level.Debug,
-            "Grouped inputs currently requires additional settings to enable the feature",
-        )
-        log(
-            Level.Debug,
-            """Adding the following ops to torch_executed_ops:
- - aten::__getitem__
- - prim::ListConstruct
- - prim::ListUnpack
- - prim::TupleIndex
- - prim::TupleConstruct
- - prim::TupleUnpack
-""",
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "aten::__getitem__"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::ListConstruct"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::TupleConstruct"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::TupleUnpack"
-        )
+        info.input_signature = _C.InputSignature(signature)
 
     else:
         raise KeyError(
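With these checks removed, a grouped-input module can request full compilation directly, and no fallback ops are forced on its behalf. A condensed usage sketch of the now-valid C++ path; the module path is hypothetical, and the complete pattern appears in the new test below:

// Illustrative sketch: grouped (tuple) inputs with full compilation requested,
// which the removed TORCHTRT_CHECK above used to reject.
#include <torch/script.h>
#include "torch_tensorrt/torch_tensorrt.h"

void grouped_input_full_compilation_demo() {
  // Hypothetical: a scripted module whose forward takes one Tuple[Tensor, Tensor].
  auto mod = torch::jit::load("module_with_tuple_input.jit.pt");

  auto in_spec = torch_tensorrt::Input(std::vector<int64_t>{1, 3, 224, 224});
  auto in_ivalue = torch::jit::IValue(c10::make_intrusive<torch_tensorrt::Input>(in_spec));

  // (Input, Input) mirrors the Tuple[Tensor, Tensor] parameter; the outer
  // one-element tuple mirrors forward's argument list.
  std::tuple<torch::jit::IValue, torch::jit::IValue> grouped(in_ivalue, in_ivalue);
  std::tuple<torch::jit::IValue> signature{torch::jit::IValue(grouped)};

  auto spec = torch_tensorrt::ts::CompileSpec(torch::jit::IValue(signature));
  spec.require_full_compilation = true;  // previously disallowed for grouped inputs
  spec.min_block_size = 1;
  auto trt_mod = torch_tensorrt::torchscript::compile(mod, spec);
}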
diff --git a/tests/cpp/test_collections.cpp b/tests/cpp/test_collections.cpp
index cbca9c7b98..943119977b 100644
--- a/tests/cpp/test_collections.cpp
+++ b/tests/cpp/test_collections.cpp
@@ -404,3 +404,65 @@ TEST(CppAPITests, TestCollectionComplexModel) {
   ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
       out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
 }
+
+TEST(CppAPITests, TestCollectionFullCompilationComplexModel) {
+  std::string path = "tests/modules/list_input_tuple_output_scripted.jit.pt";
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
+  std::vector<torch::Tensor> inputs;
+  inputs.push_back(in0);
+
+  torch::jit::Module mod;
+  try {
+    // Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(path);
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+  }
+  mod.eval();
+  mod.to(torch::kCUDA);
+
+  std::vector<torch::jit::IValue> inputs_;
+
+  for (auto in : inputs) {
+    inputs_.push_back(torch::jit::IValue(in.clone()));
+  }
+
+  std::vector<torch::jit::IValue> complex_inputs;
+  auto input_list = c10::impl::GenericList(c10::TensorType::get());
+  input_list.push_back(inputs_[0]);
+  input_list.push_back(inputs_[0]);
+
+  torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
+
+  complex_inputs.push_back(input_list_ivalue);
+
+  auto out = mod.forward(complex_inputs);
+
+  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
+
+  auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
+
+  c10::TypePtr elementType = input_shape_ivalue.type();
+  auto list = c10::impl::GenericList(elementType);
+  list.push_back(input_shape_ivalue);
+  list.push_back(input_shape_ivalue);
+
+  torch::jit::IValue complex_input_shape(list);
+  std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
+  torch::jit::IValue complex_input_shape2(input_tuple2);
+
+  auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
+  compile_settings.min_block_size = 1;
+  compile_settings.require_full_compilation = true;
+
+  // FP16 execution
+  compile_settings.enabled_precisions = {torch::kHalf};
+  // Compile module
+  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
+  auto trt_out = trt_mod.forward(complex_inputs);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
+      out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor()));
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
+      out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
+}
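One API detail worth calling out from the test above: a List[Tensor] parameter is described with a c10::impl::GenericList of Input IValues, whereas a Tuple parameter uses std::tuple, as in the earlier sketch. A condensed, illustrative distillation of just the spec-construction step:

// Condensed from the test above: build (List[Input],) to mirror forward(List[Tensor]).
auto in_spec = torch_tensorrt::Input(std::vector<int64_t>{1, 3, 512, 512}, torch_tensorrt::DataType::kHalf);
auto in_ivalue = torch::jit::IValue(c10::make_intrusive<torch_tensorrt::Input>(in_spec));
auto list = c10::impl::GenericList(in_ivalue.type());  // element type taken from the IValue
list.push_back(in_ivalue);
list.push_back(in_ivalue);
std::tuple<torch::jit::IValue> signature{torch::jit::IValue(list)};  // outer tuple = argument list
auto spec = torch_tensorrt::ts::CompileSpec(torch::jit::IValue(signature));
spec.require_full_compilation = true;  // no longer requires forced fallback ops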
diff --git a/tests/py/api/test_collections.py b/tests/py/api/test_collections.py
index 12c1ac9f50..64f46fa3e9 100644
--- a/tests/py/api/test_collections.py
+++ b/tests/py/api/test_collections.py
@@ -194,6 +194,34 @@ def test_compile(self):
             msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )
 
+    def test_compile_full_compilation(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                (torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+
 
 class TestListInputOutput(unittest.TestCase):
     def test_compile(self):
@@ -225,6 +253,36 @@ def test_compile(self):
             msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )
 
+    def test_compile_full_compilation(self):
+
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                [torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+
 
 class TestListInputTupleOutput(unittest.TestCase):
     def test_compile(self):
@@ -255,6 +313,35 @@ def test_compile(self):
             msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )
 
+    def test_compile_full_compilation(self):
+
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                [torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
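The tests in this diff cover tuple and list collections but not the DictType branch newly added to OutputIsCollection. A hypothetical graph that would exercise it, reusing the parseIR setup from the first sketch (illustrative IR only, not part of this diff):

// Illustrative sketch: a graph returning Dict[str, Tensor], which
// OutputIsCollection now classifies as a collection via the DictType check.
const std::string dict_out_ir = R"IR(
  graph(%x : Tensor):
    %k : str = prim::Constant[value="out"]()
    %d : Dict(str, Tensor) = prim::DictConstruct(%k, %x)
    return (%d))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(dict_out_ir, g.get());
// Expected: OutputIsCollection(g->block()) == true (DictType output), while
// InputIsCollection(g->block()) == false; note InputIsCollection checks only
// Tuple and List kinds, so a dict input still would not trigger collection handling.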