Skip to content

Commit

Permalink
[ovc] check if input is correct in split_inputs (#19350)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir authored Aug 30, 2023
1 parent 02d6c1c commit 9a1726a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
5 changes: 5 additions & 0 deletions tools/ovc/openvino/tools/ovc/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,11 @@ def parse_input_value(input_value: str):


def split_inputs(input_str):
pattern = r'^(?:[^[\]()<]*(\[[\.)-9,\-\s?]*\])*,)*[^[\]()<]*(\[[\.0-9,\-\s?]*\])*$'
if not re.match(pattern, input_str):
raise Error(f"input value '{input_str}' is incorrect. Input should be in the following format: "
f"{get_convert_model_help_specifics()['input']['description']}")

brakets_count = 0
inputs = []
while input_str:
Expand Down
2 changes: 1 addition & 1 deletion tools/ovc/openvino/tools/ovc/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def convert_model(
input_model: [str, pathlib.Path, Any, list], # TODO: Instead of list just accept arbitrary number of positional arguments

# Framework-agnostic parameters
input: [list, dict] = None,
input: [list, dict, str] = None,
output: [str, list] = None,
example_input: Any = None,
extension: [str, pathlib.Path, list, Any] = None,
Expand Down
32 changes: 32 additions & 0 deletions tools/ovc/unit_tests/ovc/utils/cli_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,38 @@ def test_get_shapes_several_inputs_several_shapes2(self):
_InputCutInfo(name='inp2', shape=PartialShape([-1,45,7,1]))]
self.assertEqual(inputs, inputs_ref)

def test_raises_get_shapes_1(self):
argv_input = "[h,y]"
self.assertRaises(Error, input_to_input_cut_info, argv_input)

def test_raises_get_shapes_2(self):
argv_input = "(2, 3)"
self.assertRaises(Error, input_to_input_cut_info, argv_input)

def test_raises_get_shapes_3(self):
argv_input = "input_1(2, 3)"
self.assertRaises(Error, input_to_input_cut_info, argv_input)

def test_raises_get_shapes_4(self):
argv_input = "(2, 3),(10, 10)"
self.assertRaises(Error, input_to_input_cut_info, argv_input)

def test_raises_get_shapes_5(self):
argv_input = "<2,3,4>"
self.assertRaises(Error, input_to_input_cut_info, argv_input)

def test_raises_get_shapes_6(self):
argv_input = "sd<2,3>"
self.assertRaises(Error, input_to_input_cut_info, argv_input)

def test_get_shapes_complex_input(self):
argv_input = "[10, -1, 100],mask[],[?,?]"
inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(shape=PartialShape([10, -1, 100])),
_InputCutInfo(name='mask', shape=PartialShape([])),
_InputCutInfo(shape=PartialShape([-1, -1]))]
self.assertEqual(inputs, inputs_ref)

def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self):
# shapes and value for freezing specified using --input command line parameter
argv_input = "inp1,inp2"
Expand Down

0 comments on commit 9a1726a

Please sign in to comment.