Skip to content

Commit

Permalink
fix PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
HollowPrincess committed Sep 9, 2024
1 parent 92b42ff commit 24085dd
Showing 1 changed file with 56 additions and 40 deletions.
96 changes: 56 additions & 40 deletions batchflow/models/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR

import openvino as ov
import shelve

from sklearn.decomposition import PCA

Expand Down Expand Up @@ -1703,8 +1701,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
path, pickle_module=pickle_module, **kwargs)

elif use_openvino:
if batch_size is None:
raise ValueError('Specify valid `batch_size`, used for model inference!')
import openvino as ov

path_openvino = path_openvino or (path + '_openvino')
if os.path.splitext(path_openvino)[-1] == '':
Expand All @@ -1722,16 +1719,14 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
# Save the rest of parameters
preserved = set(self.PRESERVE) - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
preserved_dict = {item: getattr(self, item) for item in preserved}
out_path_params = f'{os.path.splitext(path_openvino)[0]}_bf_params_db'

with shelve.open(out_path_params) as params_db:
params_db.update(preserved_dict)
torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict},
path, pickle_module=pickle_module, **kwargs)

else:
torch.save({item: getattr(self, item) for item in self.PRESERVE},
path, pickle_module=pickle_module, **kwargs)

def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
""" Load a torch model from a file.
If the model was saved in ONNX format (refer to :meth:`.save` for more info), we fix the microbatch size
Expand All @@ -1741,8 +1736,6 @@ def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval',
----------
file : str, PathLike, io.Bytes
a file where a model is stored.
is_openvino : bool, default False
Whether the load file as openvino model instance.
make_infrastructure : bool
Whether to re-create model loss, optimizer, scaler and decay.
mode : str
Expand All @@ -1752,39 +1745,40 @@ def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval',
kwargs : dict
Other keyword arguments, passed directly to :func:`torch.save`.
"""
self._parse_devices()
model_load_kwargs = kwargs.pop('model_load_kwargs', {})

if is_openvino:
device = kwargs.pop('device', None) or self.device or 'CPU'
self.device = device.lower()
device = kwargs.pop('device', None)

model = OVModel(model_path=file, device=device, **kwargs)
self.model = model
if device is not None:
self.device = device

# Load params
out_path_params = f'{os.path.splitext(file)[0]}_bf_params_db'
with shelve.open(out_path_params) as params_db:
params = {**params_db}
if (self.device == 'cpu') or ((not isinstance(self.device, str)) and (self.device.type == 'cpu')):
self.amp = False
else:
self._parse_devices()

for key, value in params.items():
setattr(self, key, value)
kwargs['map_location'] = self.device

self._loaded_from_openvino = True
self.disable_training = True
else:
kwargs['map_location'] = self.device if self.device else 'cpu'
# Load items from disk storage and set them as insance attributes
checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs)

# Load items from disk storage and set them as insance attributes
checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs)
# `load_config` is a reference to `self.external_config` used to update `config`
# It is required since `self.external_config` may be overwritten in the cycle below
load_config = self.external_config

# `load_config` is a reference to `self.external_config` used to update `config`
# It is required since `self.external_config` may be overwritten in the cycle below
load_config = self.external_config
for key, value in checkpoint.items():
setattr(self, key, value)
self.config = self.config + load_config

for key, value in checkpoint.items():
setattr(self, key, value)
self.config = self.config + load_config
if 'openvino' in checkpoint:
# Load openvino model
model = OVModel(model_path=checkpoint['path_openvino'], **model_load_kwargs)
self.model = model

self._loaded_from_openvino = True
self.disable_training = True

else:
# Load model from onnx, if needed
if 'onnx' in checkpoint:
try:
Expand Down Expand Up @@ -1957,25 +1951,47 @@ def reduce_channels(array, normalize=True, n_components=3):
return compressed_array, explained_variance_ratio

class OVModel:
def __init__(self, model_path, core_config=None, device='CPU', compile_config=None):
""" Class-wrapper for openvino models to interact with them through :class:`~.TorchModel` interface.
Note, openvino models are loaded on 'cpu' only.
Parameters
----------
model_path : str
Path to compiled openvino model.
core_config : tuple or dict, optional
Openvino core properties.
If you want set properties globally provide them as tuple: `('CPU', {name: value})`.
For local properties just provide `{name: value}` dict.
For more, read the documentation:
https://docs.openvino.ai/2023.3/openvino_docs_OV_UG_query_api.html#setting-properties-globally
compile_config : dict, optional
Openvino model compilation config.
"""
def __init__(self, model_path, core_config=None, compile_config=None):
import openvino as ov

core = ov.Core()

if core_config is not None:
for name, kwargs_ in core_config.items():
core.set_property(name, kwargs_)
if isinstance(core_config, tuple):
core.set_property(core_config[0], core_config[1])
else:
core.set_property(core_config)

self.model = core.read_model(model=model_path)

if compile_config is None:
compile_config = {}
self.model = core.compile_model(self.model, device, config=compile_config)

self.model = core.compile_model(self.model, 'CPU', config=compile_config)

def eval(self):
""" Placeholder for compatibility with :class:`~TorchModel` methods."""
pass

def __call__(self, input_tensor):
""" Evaluate model on provided data. """
""" Evaluate model on the provided data. """
results = self.model(input_tensor)

results = torch.from_numpy(results[self.model.output(0)])
Expand Down

0 comments on commit 24085dd

Please sign in to comment.