Skip to content

Commit

Permalink
load using FE
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Nov 18, 2021
1 parent 5661ae5 commit a2be8a9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ def add_openvino_specific_args(parser):
type=cast_to_bool,
required=False
)
openvino_specific_args.add_argument(
'--model_type',
help='model format for automatic search (e.g. blob, xml, onnx)',
required=False
)
openvino_specific_args.add_argument(
'-C', '--converted_models',
help='directory to store Model Optimizer converted models. Used for DLSDK launcher only',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, config_entry, model_name='', delayed_model_loading=False,
self._model, self._weights = automatic_model_search(
self._model_name, self.get_value_from_config('model'),
self.get_value_from_config('weights'),
self.get_value_from_config('_model_is_blob')
self.get_value_from_config('_model_type')
)
self.load_network(log=True, preprocessing=preprocessor)
self.allow_reshape_input = self.get_value_from_config('allow_reshape_input') and self.network is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,10 @@ def check_model_source(entry, fetch_only=False, field_uri=None, validation_schem
'default - try to run as default, if does not work switch to static, '
'dynamic - enforce network execution with dynamic shapes, '
'static - convert undefined shapes to static before execution'
)
),
'_model_type': StringField(
choices=['xml', 'blob', 'onnx', 'paddle', 'tf'],
description='hint for model type in automatic model search')
}


Expand Down Expand Up @@ -371,41 +374,36 @@ def mo_convert_model(config, launcher_parameters, framework=None):
)


def automatic_model_search(model_name, model_cfg, weights_cfg, model_is_blob):
def get_xml(model_dir):
models_list = list(model_dir.glob('{}.xml'.format(model_name)))
if not models_list:
models_list = list(model_dir.glob('*.xml'))
return models_list

def get_blob(model_dir):
blobs_list = list(Path(model_dir).glob('{}.blob'.format(model_name)))
if not blobs_list:
blobs_list = list(Path(model_dir).glob('*.blob'))
return blobs_list

def get_onnx(model_dir):
onnx_list = list(Path(model_dir).glob('{}.onnx'.format(model_name)))
if not onnx_list:
onnx_list = list(Path(model_dir).glob('*.onnx'))
return onnx_list
def automatic_model_search(model_name, model_cfg, weights_cfg, model_type=None):
model_type_ext = {
'xml': 'xml',
'blob': 'blob',
'onnx': 'onnx',
'paddle': 'pdmodel',
'tf': 'pb'
}
def get_model_by_suffix(model_name, model_dir, suffix):
model_list = list(Path(model_dir).glob('{}.{}'.format(model_name, suffix)))
if not model_list:
model_list = list(Path(model_dir).glob('*.{}'.format(suffix)))
return model_list

def get_model():
model = Path(model_cfg)
if not model.is_dir():
accepted_suffixes = ['.blob', '.onnx', '.xml']
accepted_suffixes = list(model_type_ext.values())
if model.suffix not in accepted_suffixes:
raise ConfigError('Models with following suffixes are allowed: {}'.format(accepted_suffixes))
print_info('Found model {}'.format(model))
return model, model.suffix == '.blob'
if model_is_blob:
model_list = get_blob(model)
model_list = []
if model_type is not None:
model_list = get_model_by_suffix(model_name, model, model_type_ext[model_type])
else:
model_list = get_xml(model)
if not model_list and model_is_blob is None:
model_list = get_blob(model)
if not model_list:
model_list = get_onnx(model)
for ext in model_type_ext.values():
model_list = get_model_by_suffix(model_name, model, ext)
if model_list:
break
if not model_list:
raise ConfigError('suitable model is not found')
if len(model_list) != 1:
Expand All @@ -418,7 +416,7 @@ def get_model():
if is_blob:
return model, None
weights = weights_cfg
if (weights is None or Path(weights).is_dir()) and model.suffix != '.onnx':
if (weights is None or Path(weights).is_dir()) and model.suffix == '.xml':
weights_dir = weights or model.parent
weights = Path(weights_dir) / model.name.replace('xml', 'bin')
if weights is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(self, config_entry, model_name='', delayed_model_loading=False,
self._model, self._weights = automatic_model_search(
self._model_name, self.get_value_from_config('model'),
self.get_value_from_config('weights'),
self.get_value_from_config('_model_is_blob')
self.get_value_from_config('_model_type')
)
self.load_network(log=True, preprocessing=preprocessor)
self.allow_reshape_input = self.get_value_from_config('allow_reshape_input') and self.network is not None
Expand Down

0 comments on commit a2be8a9

Please sign in to comment.