diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py index 28bd6bbd2dfbb0..af8eafda8e9be7 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py @@ -73,7 +73,6 @@ def __init__( "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html." ) from e self.graph_element = pt_module.inlined_graph - log.debug("Inlined graph:\n%s", pt_module.inlined_graph) self.alias_db = self.graph_element.alias_db() else: self.graph_element = graph_element @@ -96,6 +95,7 @@ def __init__( self._transform_tensor_list_constants_to_listconstruct( self.graph_element) self._transform_optional_constants(self.graph_element) + log.debug("Inlined graph:\n%s", self.graph_element) @staticmethod def _get_preserved_attributes(model) -> list: @@ -293,11 +293,13 @@ def decoder_type_name(self) -> str: return "ts" def get_subgraphs(self) -> list: - if self.graph_element.kind() == "prim::PythonOp": + if self.graph_element.kind() in ["prim::PythonOp", "prim::fork"]: if "Subgraph" in self.graph_element.attributeNames(): assert isinstance( self.graph_element, torch.Node), "Graph element must be of type torch.Node." - return [getattr(self.graph_element, self.graph_element.kindOf("Subgraph"))("Subgraph")] + subgraph = getattr(self.graph_element, self.graph_element.kindOf("Subgraph"))("Subgraph") + torch._C._jit_pass_inline(subgraph) + return [subgraph] else: # Attribute "Subgraph" is only available if Graph was created using tracing. # TODO Find way to extract subgraph for scripted Graph. @@ -305,10 +307,17 @@ def get_subgraphs(self) -> list: return list(self.graph_element.blocks()) def get_subgraph_decoder(self, index: int): - decoder = TorchScriptPythonDecoder( - self.pt_module, self.get_subgraphs( - )[index], alias_db=self.alias_db, shared_memory=self._shared_memory, module_extensions=self.module_extensions - ) + module = self.pt_module + if self.graph_element.kind() == "prim::fork": + in0 = self.raw_inputs[0] + if in0.node().kind() == "prim::GetAttr": + module, _ = get_value_from_getattr(in0.node(), self.pt_module) + decoder = TorchScriptPythonDecoder(module, + self.get_subgraphs()[index], + alias_db=self.alias_db, + shared_memory=self._shared_memory, + module_extensions=self.module_extensions + ) self.m_decoders.append(decoder) return decoder @@ -456,8 +465,8 @@ def as_string(self): @staticmethod def _as_constant_list(pt_value: torch.Value): - # For now it is treat a list as a 1D tensor; it is required by converters to avoid need to massively - # rewrite them in that part where constant attributes are queried + # For now we treat a list as a 1D tensor; it is required by converters to avoid + # need to massively rewrite them in that part where constant attributes are queried pt_element_type = str(pt_value.type().getElementType()) ivalue = pt_value.toIValue() is_known_type = pt_element_type in pt_to_ov_type_map @@ -467,6 +476,7 @@ def _as_constant_list(pt_value: torch.Value): ovshape = PartialShape([len(ivalue)]) ov_const = op.Constant(ovtype, ovshape.get_shape(), ivalue) return ov_const.outputs() + return [] def _get_device_string(self) -> str: assert self.graph_element.kind( diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 31cf99a2e1b9d7..141e5b02ad8d25 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -675,6 +675,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::var_mean", op::translate_var_mean}, {"aten::view", op::quantizable_op}, {"aten::view_as", op::translate_reshape_as}, + {"aten::wait", op::skip_node}, {"aten::where", op::translate_where}, {"aten::zero", op::translate_zeros_like}, {"aten::zeros", op::translate_zeros}, @@ -685,6 +686,7 @@ const std::unordered_map get_supported_ops_ts() { {"prim::Constant", op::translate_constant}, {"prim::device", op::translate_constant}, // prim::DictConstruct - Supported in limited set of patterns + {"prim::fork", op::translate_pythonop}, {"prim::GetAttr", op::translate_get_attr}, {"prim::If", op::translate_if}, {"prim::is_cuda", op::return_false_scalar}, diff --git a/tests/layer_tests/pytorch_tests/test_fork_wait.py b/tests/layer_tests/pytorch_tests/test_fork_wait.py new file mode 100644 index 00000000000000..577d20aba83afc --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_fork_wait.py @@ -0,0 +1,38 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestForkWait(PytorchLayerTest): + + def _prepare_input(self): + return (np.random.randn(10, 20),) + + def create_model(self): + + class AddMod(torch.nn.Module): + def forward(self, a: torch.Tensor, b: int): + return a + b, a - b + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = AddMod() + + def forward(self, input): + fut = torch.jit.fork(self.mod, a=input, b=2) + return torch.jit.wait(fut) + + return Mod(), None, ["prim::fork", "aten::wait"] + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize(("to_trace"), [True, False]) + def test_fork_wait(self, to_trace, ie_device, precision, ir_version): + self._test(*self.create_model(), ie_device, precision, + ir_version, trace_model=to_trace)