Skip to content

Commit

Permalink
[PT FE] Partially disable freezing for int8 and uint8 weights (#18827)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin authored Jul 28, 2023
1 parent 9bf5b6e commit 481721e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 15 deletions.
19 changes: 15 additions & 4 deletions src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
import numpy as np

wrapper_template="""
wrapper_template = """
import torch
from typing import *
Expand All @@ -27,6 +27,7 @@ def forward(self, {input_sign}):
return self.model({example_input})
"""


class TorchScriptPythonDecoder (Decoder):
def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None):
Decoder.__init__(self)
Expand Down Expand Up @@ -64,6 +65,15 @@ def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=N
self._transform_tensor_list_constants_to_listconstruct(self.graph_element)
self._transform_optional_constants(self.graph_element)

@staticmethod
def _get_preserved_attributes(model) -> list:
preserved_attributes = []
for name, module in model.named_modules():
if hasattr(module, "weight"):
if module.weight.dtype in [torch.int8, torch.uint8]:
preserved_attributes.append(name)
return preserved_attributes

def _get_scripted_model(self, pt_module, example_inputs=None):
import torch
import inspect
Expand Down Expand Up @@ -156,12 +166,13 @@ def prepare_example_inputs_and_model(inputs, input_params, model):
first_input = next(n.inputs())
if first_input.node().kind() == "prim::Constant":
ivalue = first_input.toIValue()
if ivalue is not None and ivalue.dtype in [torch.uint8, torch.int8, torch.bfloat16, torch.float16]:
if ivalue is not None and ivalue.dtype in [torch.bfloat16, torch.float16]:
# do not freeze models with compressed constants
skip_freeze = True
break
if not skip_freeze:
f_model = torch.jit.freeze(scripted)
preserved_attrs = self._get_preserved_attributes(scripted)
f_model = torch.jit.freeze(scripted, preserved_attrs=preserved_attrs)
else:
f_model = scripted
else:
Expand Down Expand Up @@ -493,4 +504,4 @@ def _transform_optional_constants(graph: torch.Graph):
const_input = graph.insertConstant(value)
const_input.node().moveBefore(node)
const_input.node().copyMetadata(node)
node.output().replaceAllUsesWith(const_input)
node.output().replaceAllUsesWith(const_input)
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ class TorchDecoder : public IDecoder {

/// Returns new nodes for inputs inlined in the op itself
// Used in Torch.FX decoder
virtual OutputVector inlined_inputs(size_t start_index) const {
return {};
}
virtual OutputVector inlined_inputs(size_t start_index) const = 0;
};

} // namespace pytorch
Expand Down
13 changes: 5 additions & 8 deletions src/frontends/pytorch/src/translate_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,11 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
// TODO: Eliminate duplication with the main code for Parameters creation
PartialShape ps = node->get_input_shape(i);
auto type = simplified_type_interpret(node->get_input_type(i));
// TODO: Use special API to set custom type specification
std::shared_ptr<v0::Parameter> parameter;
// TODO: Use decoder type or explore adding the missing cast types to Torchscript path
const char* torch_tracing_mode = std::getenv("PYTORCH_TRACING_MODE");
if ((torch_tracing_mode != nullptr) && std::strcmp(torch_tracing_mode, "TORCHFX") == 0)
parameter = std::make_shared<v0::Parameter>(type.as<element::Type>(), ps);
else
parameter = std::make_shared<v0::Parameter>(element::dynamic, ps);
auto dtype = element::dynamic;
if (type.is<element::Type>()) {
dtype = type.as<element::Type>();
}
auto parameter = std::make_shared<v0::Parameter>(dtype, ps);
// TODO: Missing get_input_transpose_order handling for not trivial layouts
(*tensor_map)[input] = parameter;
// set name of parameter to the index of node in the model
Expand Down

0 comments on commit 481721e

Please sign in to comment.