From 9a1726a4193d1ea0dfd6dd5ad13e22c17ba96dcb Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Wed, 30 Aug 2023 15:13:43 +0200 Subject: [PATCH] [ovc] check if input is correct in split_inputs (#19350) --- tools/ovc/openvino/tools/ovc/cli_parser.py | 5 +++ tools/ovc/openvino/tools/ovc/convert.py | 2 +- .../unit_tests/ovc/utils/cli_parser_test.py | 32 +++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/tools/ovc/openvino/tools/ovc/cli_parser.py b/tools/ovc/openvino/tools/ovc/cli_parser.py index 2cd0ad96943a50..8568c12a47a53f 100644 --- a/tools/ovc/openvino/tools/ovc/cli_parser.py +++ b/tools/ovc/openvino/tools/ovc/cli_parser.py @@ -609,6 +609,11 @@ def parse_input_value(input_value: str): def split_inputs(input_str): + pattern = r'^(?:[^[\]()<]*(\[[\.)-9,\-\s?]*\])*,)*[^[\]()<]*(\[[\.0-9,\-\s?]*\])*$' + if not re.match(pattern, input_str): + raise Error(f"input value '{input_str}' is incorrect. Input should be in the following format: " + f"{get_convert_model_help_specifics()['input']['description']}") + brakets_count = 0 inputs = [] while input_str: diff --git a/tools/ovc/openvino/tools/ovc/convert.py b/tools/ovc/openvino/tools/ovc/convert.py index d2b008adc63142..509761193b446b 100644 --- a/tools/ovc/openvino/tools/ovc/convert.py +++ b/tools/ovc/openvino/tools/ovc/convert.py @@ -15,7 +15,7 @@ def convert_model( input_model: [str, pathlib.Path, Any, list], # TODO: Instead of list just accept arbitrary number of positional arguments # Framework-agnostic parameters - input: [list, dict] = None, + input: [list, dict, str] = None, output: [str, list] = None, example_input: Any = None, extension: [str, pathlib.Path, list, Any] = None, diff --git a/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py b/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py index 506f968c9a3124..81696a2b098079 100644 --- a/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py +++ b/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py @@ -31,6 +31,38 @@ def test_get_shapes_several_inputs_several_shapes2(self): _InputCutInfo(name='inp2', shape=PartialShape([-1,45,7,1]))] self.assertEqual(inputs, inputs_ref) + def test_raises_get_shapes_1(self): + argv_input = "[h,y]" + self.assertRaises(Error, input_to_input_cut_info, argv_input) + + def test_raises_get_shapes_2(self): + argv_input = "(2, 3)" + self.assertRaises(Error, input_to_input_cut_info, argv_input) + + def test_raises_get_shapes_3(self): + argv_input = "input_1(2, 3)" + self.assertRaises(Error, input_to_input_cut_info, argv_input) + + def test_raises_get_shapes_4(self): + argv_input = "(2, 3),(10, 10)" + self.assertRaises(Error, input_to_input_cut_info, argv_input) + + def test_raises_get_shapes_5(self): + argv_input = "<2,3,4>" + self.assertRaises(Error, input_to_input_cut_info, argv_input) + + def test_raises_get_shapes_6(self): + argv_input = "sd<2,3>" + self.assertRaises(Error, input_to_input_cut_info, argv_input) + + def test_get_shapes_complex_input(self): + argv_input = "[10, -1, 100],mask[],[?,?]" + inputs = input_to_input_cut_info(argv_input) + inputs_ref = [_InputCutInfo(shape=PartialShape([10, -1, 100])), + _InputCutInfo(name='mask', shape=PartialShape([])), + _InputCutInfo(shape=PartialShape([-1, -1]))] + self.assertEqual(inputs, inputs_ref) + def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self): # shapes and value for freezing specified using --input command line parameter argv_input = "inp1,inp2"