diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py index b924845ce5a736..ce8c60c864dbaf 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py @@ -14,7 +14,7 @@ import torch import numpy as np -wrapper_template=""" +wrapper_template = """ import torch from typing import * @@ -27,6 +27,7 @@ def forward(self, {input_sign}): return self.model({example_input}) """ + class TorchScriptPythonDecoder (Decoder): def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None): Decoder.__init__(self) @@ -64,6 +65,15 @@ def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=N self._transform_tensor_list_constants_to_listconstruct(self.graph_element) self._transform_optional_constants(self.graph_element) + @staticmethod + def _get_preserved_attributes(model) -> list: + preserved_attributes = [] + for name, module in model.named_modules(): + if hasattr(module, "weight"): + if module.weight.dtype in [torch.int8, torch.uint8]: + preserved_attributes.append(name) + return preserved_attributes + def _get_scripted_model(self, pt_module, example_inputs=None): import torch import inspect @@ -156,12 +166,13 @@ def prepare_example_inputs_and_model(inputs, input_params, model): first_input = next(n.inputs()) if first_input.node().kind() == "prim::Constant": ivalue = first_input.toIValue() - if ivalue is not None and ivalue.dtype in [torch.uint8, torch.int8, torch.bfloat16, torch.float16]: + if ivalue is not None and ivalue.dtype in [torch.bfloat16, torch.float16]: # do not freeze models with compressed constants skip_freeze = True break if not skip_freeze: - f_model = torch.jit.freeze(scripted) + preserved_attrs = self._get_preserved_attributes(scripted) + f_model = torch.jit.freeze(scripted, preserved_attrs=preserved_attrs) else: f_model = scripted else: @@ -493,4 +504,4 @@ def _transform_optional_constants(graph: torch.Graph): const_input = graph.insertConstant(value) const_input.node().moveBefore(node) const_input.node().copyMetadata(node) - node.output().replaceAllUsesWith(const_input) \ No newline at end of file + node.output().replaceAllUsesWith(const_input) diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp index da3c76544b47b3..21517bea278555 100644 --- a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp +++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp @@ -112,9 +112,7 @@ class TorchDecoder : public IDecoder { /// Returns new nodes for inputs inlined in the op itself // Used in Torch.FX decoder - virtual OutputVector inlined_inputs(size_t start_index) const { - return {}; - } + virtual OutputVector inlined_inputs(size_t start_index) const = 0; }; } // namespace pytorch diff --git a/src/frontends/pytorch/src/translate_session.cpp b/src/frontends/pytorch/src/translate_session.cpp index b9335fa6ac1121..894b6bd3f15c20 100644 --- a/src/frontends/pytorch/src/translate_session.cpp +++ b/src/frontends/pytorch/src/translate_session.cpp @@ -162,14 +162,11 @@ std::shared_ptr TranslateSession::convert_pytorch_model( // TODO: Eliminate duplication with the main code for Parameters creation PartialShape ps = node->get_input_shape(i); auto type = simplified_type_interpret(node->get_input_type(i)); - // TODO: Use special API to set custom type specification - std::shared_ptr parameter; - // TODO: Use decoder type or explore adding the missing cast types to Torchscript path - const char* torch_tracing_mode = std::getenv("PYTORCH_TRACING_MODE"); - if ((torch_tracing_mode != nullptr) && std::strcmp(torch_tracing_mode, "TORCHFX") == 0) - parameter = std::make_shared(type.as(), ps); - else - parameter = std::make_shared(element::dynamic, ps); + auto dtype = element::dynamic; + if (type.is()) { + dtype = type.as(); + } + auto parameter = std::make_shared(dtype, ps); // TODO: Missing get_input_transpose_order handling for not trivial layouts (*tensor_map)[input] = parameter; // set name of parameter to the index of node in the model