From 85bcf2187f6d835f988ea85d9daf35c6f75e5024 Mon Sep 17 00:00:00 2001 From: quzha Date: Wed, 14 Jul 2021 10:07:40 +0800 Subject: [PATCH 1/3] support more ops for torch 1.9 --- nni/retiarii/operation_def/torch_op_def.py | 19 ++++++++++++++++--- test/ut/retiarii/test_convert_pytorch.py | 3 ++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/nni/retiarii/operation_def/torch_op_def.py b/nni/retiarii/operation_def/torch_op_def.py index bb97069e63..313a5558af 100644 --- a/nni/retiarii/operation_def/torch_op_def.py +++ b/nni/retiarii/operation_def/torch_op_def.py @@ -59,7 +59,7 @@ class PrimConstant(PyTorchOperation): def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: # TODO: refactor this part, maybe we can remove the code gen of prim::Constant # TODO: deal with all the types - if self.parameters['type'] == 'None': + if self.parameters['type'] in ['None', 'NoneType']: return f'{output} = None' elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'): return f'{output} = {self.parameters["value"]}' @@ -238,7 +238,13 @@ def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_val ManuallyChooseDef = { 'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')], - 'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')] + 'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')], + # in v1.9 dtype is supported as input argument for view, but torch script does not support it + 'aten::view': [('size', 'List[int]', 'None')], + # NOTE: dim supports different types: List[int], List[str], Optional[List[int]], now we only support the first two, refactor needed + # torch.std(input, dim, unbiased, keepdim=False, *, out=None) Tensor + # torch.std(input, unbiased) Tensor + 'aten::std': [('dim', 'List[int]', 'None'), ('unbiased', 'bool', 'True'), ('keepdim', 'bool', 'False')] } TensorOpExceptions = { @@ -426,4 +432,11 @@ class AtenAvgpool2d(PyTorchOperation): # NOTE: it is not included in the above aten ops for unkown reason _ori_type_name = ['aten::avg_pool2d'] def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: - return f'{output} = F.avg_pool2d({", ".join(inputs)})' \ No newline at end of file + return f'{output} = F.avg_pool2d({", ".join(inputs)})' + +class AtenDet(PyTorchOperation): + # for torch 1.9 + # NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det + _ori_type_name = ['aten::linalg_det'] + def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: + return f'{output} = torch.det({inputs[0]})' \ No newline at end of file diff --git a/test/ut/retiarii/test_convert_pytorch.py b/test/ut/retiarii/test_convert_pytorch.py index dbcf1acd31..083778ecf1 100644 --- a/test/ut/retiarii/test_convert_pytorch.py +++ b/test/ut/retiarii/test_convert_pytorch.py @@ -375,7 +375,8 @@ def forward(self, input): # NOTE: torch script gets an incorrect graph... def test_optional_inputs_with_mixed_optionals(self): class MixedModel(nn.Module): - def forward(self, x: 'Tensor', y: 'Tensor', z: 'Tensor'): + #def forward(self, x: 'Tensor', y: 'Tensor', z: 'Tensor'): + def forward(self, x, y, z): # NOTE: torch 1.9 does not support the type string 'Tensor' if y is not None: return x + y if z is not None: From 6a55dd5cb3c6590f060d3e820f2b3c72b738d96c Mon Sep 17 00:00:00 2001 From: quzha Date: Wed, 14 Jul 2021 10:12:40 +0800 Subject: [PATCH 2/3] update doc --- docs/en_US/NAS/QuickStart.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en_US/NAS/QuickStart.rst b/docs/en_US/NAS/QuickStart.rst index 01ec33f540..3dc3670ad4 100644 --- a/docs/en_US/NAS/QuickStart.rst +++ b/docs/en_US/NAS/QuickStart.rst @@ -12,7 +12,7 @@ In this quick start tutorial, we use multi-trial NAS as an example to show how t One-shot NAS tutorial can be found `here <./OneshotTrainer.rst>`__. -.. note:: Currently, PyTorch is the only supported framework by Retiarii, and we have only tested with **PyTorch 1.6 and 1.7**. This documentation assumes PyTorch context but it should also apply to other frameworks, that is in our future plan. +.. note:: Currently, PyTorch is the only supported framework by Retiarii, and we have only tested with **PyTorch 1.6 to 1.9**. This documentation assumes PyTorch context but it should also apply to other frameworks, that is in our future plan. Define your Model Space ----------------------- From 256c64e72e68c3c2d8eb181a6a075052bc3afa6e Mon Sep 17 00:00:00 2001 From: quzha Date: Wed, 14 Jul 2021 10:14:38 +0800 Subject: [PATCH 3/3] minor --- test/ut/retiarii/test_convert_pytorch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/ut/retiarii/test_convert_pytorch.py b/test/ut/retiarii/test_convert_pytorch.py index 083778ecf1..51857b6815 100644 --- a/test/ut/retiarii/test_convert_pytorch.py +++ b/test/ut/retiarii/test_convert_pytorch.py @@ -375,8 +375,7 @@ def forward(self, input): # NOTE: torch script gets an incorrect graph... def test_optional_inputs_with_mixed_optionals(self): class MixedModel(nn.Module): - #def forward(self, x: 'Tensor', y: 'Tensor', z: 'Tensor'): - def forward(self, x, y, z): # NOTE: torch 1.9 does not support the type string 'Tensor' + def forward(self, x, y, z): if y is not None: return x + y if z is not None: