[PT FE] Support prim::fork and aten::wait (openvinotoolkit#26839)
### Details:
 - *Support `prim::fork` and `aten::wait`*

### Tickets:
 - *CVS-153613*

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
mvafin and mlukasze authored Oct 8, 2024
1 parent 4b43150 commit 737caf5
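
For context: `torch.jit.fork` starts a TorchScript function or module asynchronously and returns a `Future`, and `torch.jit.wait` blocks until the result is ready. In scripted IR these show up as `prim::fork` and `aten::wait` nodes, the two ops this commit adds support for. A minimal sketch (illustrative, not part of this commit) that produces both:

```python
import torch


class Mod(torch.nn.Module):
    def forward(self, x):
        # fork launches torch.relu asynchronously and returns a Future[Tensor]
        fut = torch.jit.fork(torch.relu, x)
        # wait blocks on the future and returns the tensor
        return torch.jit.wait(fut)


scripted = torch.jit.script(Mod())
# The printed IR contains prim::fork and aten::wait nodes
print(scripted.graph)
```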
Showing 3 changed files with 59 additions and 9 deletions.
28 changes: 19 additions & 9 deletions src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
```diff
@@ -73,7 +73,6 @@ def __init__(
                     "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html."
                 ) from e
             self.graph_element = pt_module.inlined_graph
-            log.debug("Inlined graph:\n%s", pt_module.inlined_graph)
             self.alias_db = self.graph_element.alias_db()
         else:
             self.graph_element = graph_element
@@ -96,6 +95,7 @@ def __init__(
         self._transform_tensor_list_constants_to_listconstruct(
             self.graph_element)
         self._transform_optional_constants(self.graph_element)
+        log.debug("Inlined graph:\n%s", self.graph_element)
 
     @staticmethod
     def _get_preserved_attributes(model) -> list:
@@ -293,22 +293,31 @@ def decoder_type_name(self) -> str:
         return "ts"
 
     def get_subgraphs(self) -> list:
-        if self.graph_element.kind() == "prim::PythonOp":
+        if self.graph_element.kind() in ["prim::PythonOp", "prim::fork"]:
             if "Subgraph" in self.graph_element.attributeNames():
                 assert isinstance(
                     self.graph_element, torch.Node), "Graph element must be of type torch.Node."
-                return [getattr(self.graph_element, self.graph_element.kindOf("Subgraph"))("Subgraph")]
+                subgraph = getattr(self.graph_element, self.graph_element.kindOf("Subgraph"))("Subgraph")
+                torch._C._jit_pass_inline(subgraph)
+                return [subgraph]
             else:
                 # Attribute "Subgraph" is only available if Graph was created using tracing.
                 # TODO Find way to extract subgraph for scripted Graph.
                 return []
         return list(self.graph_element.blocks())
 
     def get_subgraph_decoder(self, index: int):
-        decoder = TorchScriptPythonDecoder(
-            self.pt_module, self.get_subgraphs(
-            )[index], alias_db=self.alias_db, shared_memory=self._shared_memory, module_extensions=self.module_extensions
-        )
+        module = self.pt_module
+        if self.graph_element.kind() == "prim::fork":
+            in0 = self.raw_inputs[0]
+            if in0.node().kind() == "prim::GetAttr":
+                module, _ = get_value_from_getattr(in0.node(), self.pt_module)
+        decoder = TorchScriptPythonDecoder(module,
+                                           self.get_subgraphs()[index],
+                                           alias_db=self.alias_db,
+                                           shared_memory=self._shared_memory,
+                                           module_extensions=self.module_extensions
+                                           )
         self.m_decoders.append(decoder)
         return decoder
 
@@ -456,8 +465,8 @@ def as_string(self):
 
     @staticmethod
     def _as_constant_list(pt_value: torch.Value):
-        # For now it is treat a list as a 1D tensor; it is required by converters to avoid need to massively
-        # rewrite them in that part where constant attributes are queried
+        # For now we treat a list as a 1D tensor; it is required by converters to avoid
+        # need to massively rewrite them in that part where constant attributes are queried
         pt_element_type = str(pt_value.type().getElementType())
         ivalue = pt_value.toIValue()
         is_known_type = pt_element_type in pt_to_ov_type_map
@@ -467,6 +476,7 @@ def _as_constant_list(pt_value: torch.Value):
             ovshape = PartialShape([len(ivalue)])
             ov_const = op.Constant(ovtype, ovshape.get_shape(), ivalue)
             return ov_const.outputs()
+        return []
 
     def _get_device_string(self) -> str:
         assert self.graph_element.kind(
```
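
The `get_subgraphs` change above relies on PyTorch attaching the forked body to the `prim::fork` node as a `Subgraph` graph attribute, the same way it does for `prim::PythonOp`. A rough sketch of that extraction outside the decoder (assuming the attribute is present; `kindOf` returns the accessor name, e.g. `g` for graph-typed attributes):

```python
import torch


class Mod(torch.nn.Module):
    def forward(self, x):
        fut = torch.jit.fork(torch.relu, x)
        return torch.jit.wait(fut)


graph = torch.jit.script(Mod()).graph
fork = next(n for n in graph.nodes() if n.kind() == "prim::fork")
if "Subgraph" in fork.attributeNames():
    # Generic attribute access, mirroring the decoder: equivalent to fork.g("Subgraph")
    subgraph = getattr(fork, fork.kindOf("Subgraph"))("Subgraph")
    # Inline nested calls so the frontend sees a flat forked body
    torch._C._jit_pass_inline(subgraph)
    print(subgraph)
```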
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
```diff
@@ -675,6 +675,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
        {"aten::var_mean", op::translate_var_mean},
        {"aten::view", op::quantizable_op<op::translate_reshape>},
        {"aten::view_as", op::translate_reshape_as},
+       {"aten::wait", op::skip_node},
        {"aten::where", op::translate_where},
        {"aten::zero", op::translate_zeros_like},
        {"aten::zeros", op::translate_zeros},
@@ -685,6 +686,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
        {"prim::Constant", op::translate_constant},
        {"prim::device", op::translate_constant},
        // prim::DictConstruct - Supported in limited set of patterns
+       {"prim::fork", op::translate_pythonop},
        {"prim::GetAttr", op::translate_get_attr},
        {"prim::If", op::translate_if},
        {"prim::is_cuda", op::return_false_scalar},
```
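
A note on this mapping (inferred from the code, not stated in the commit message): `prim::fork` reuses `op::translate_pythonop`, which converts the node's subgraph inline, so the forked computation simply becomes part of the static OpenVINO graph and runs synchronously. With the result already materialized, `aten::wait` has nothing left to do and is mapped to `op::skip_node`, which forwards its input unchanged.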
38 changes: 38 additions & 0 deletions tests/layer_tests/pytorch_tests/test_fork_wait.py
```diff
@@ -0,0 +1,38 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestForkWait(PytorchLayerTest):
+
+    def _prepare_input(self):
+        return (np.random.randn(10, 20),)
+
+    def create_model(self):
+
+        class AddMod(torch.nn.Module):
+            def forward(self, a: torch.Tensor, b: int):
+                return a + b, a - b
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mod = AddMod()
+
+            def forward(self, input):
+                fut = torch.jit.fork(self.mod, a=input, b=2)
+                return torch.jit.wait(fut)
+
+        return Mod(), None, ["prim::fork", "aten::wait"]
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize(("to_trace"), [True, False])
+    def test_fork_wait(self, to_trace, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision,
+                   ir_version, trace_model=to_trace)
```
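
With this support in place, a fork/wait model should convert end to end through the PyTorch frontend. A usage sketch (the model and shapes are illustrative, not from the test; `ov.convert_model` and `ov.compile_model` are the standard OpenVINO Python APIs):

```python
import numpy as np
import openvino as ov
import torch


class Mod(torch.nn.Module):
    def forward(self, x):
        fut = torch.jit.fork(torch.relu, x)
        return torch.jit.wait(fut)


scripted = torch.jit.script(Mod())
# Convert the scripted module; example_input lets the frontend infer shapes
ov_model = ov.convert_model(scripted, example_input=torch.randn(10, 20))
compiled = ov.compile_model(ov_model, "CPU")
result = compiled(np.random.randn(10, 20).astype(np.float32))[0]
```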
