From 0586bd4db84037e6bbddb806a9d90eb47f75c118 Mon Sep 17 00:00:00 2001
From: Daisy Deng
Date: Fri, 30 Aug 2024 10:14:01 +0800
Subject: [PATCH] enable hook for sample_inputs_index_put_nofp64 and
 reference_inputs_c… (#846)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The original hooks for sample_inputs_index_put_nofp64 and
reference_inputs_cat were not stable; replace the functions directly on
the matching op_db entries instead to make the substitution reliable.

Co-authored-by: Zhong, Ruijie
---
 test/xpu/xpu_test_utils.py | 33 ++++++++++++++++++++-------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index ca53831f5..92435c9b6 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -570,7 +570,7 @@ def convert_dtype(obj, dtype, requires_grad=False):
     CriterionTest.test_cuda = CriterionTest_test_xpu
 
 from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat, S, M
-from torch.testing._internal.common_methods_invocations import make_tensor
+from torch.testing._internal.common_methods_invocations import make_tensor, mask_not_all_zeros
 from functools import partial
 from torch.testing._internal.opinfo.core import SampleInput
 
@@ -604,6 +604,21 @@ def index_variable_nofp64(shape, max_indices, device=torch.device('cpu')):
     index = torch.rand(*shape, dtype=torch.float32, device=device).mul_(max_indices).floor_().long()
     return index
 
+def sample_inputs_index_put_nofp64(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    for accumulate in [False, True]:
+        # Test with indices arg
+        yield SampleInput(
+            make_arg((S, S)),
+            (index_variable_nofp64(2, S, device=device),),
+            make_arg((2, S)),
+            accumulate=accumulate)
+
+        # Test with mask arg
+        mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,))
+        yield SampleInput(
+            make_arg((S, S)), (mask,), make_arg((S,)), accumulate=accumulate)
 
 def sample_inputs_softmax_variant_nofp64(
     op_info,
@@ -695,9 +710,6 @@ def __init__(self, patch_test_case=True) -> None:
         self.cuda_is_available = cuda.is_available
         self.cuda_is_bf16_supported = cuda.is_bf16_supported
 
-        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            self.index_variable = common_methods_invocations.index_variable
-            self.reference_inputs_cat = common_methods_invocations.reference_inputs_cat
 
     def align_db_decorators(self, db):
         def gen_xpu_wrappers(op_name, wrappers):
@@ -774,18 +786,17 @@ def filter_fp64_sample_input(self, db):
                     opinfo.sample_inputs_func = sample_inputs_softmax_variant_nofp64
                 elif opinfo.sample_inputs_func.__name__ == common_methods_invocations.sample_inputs_like_fns.__name__:
                     opinfo.sample_inputs_func = sample_inputs_like_fns_nofp64
+                elif opinfo.sample_inputs_func.__name__ == common_methods_invocations.sample_inputs_index_put.__name__:
+                    opinfo.sample_inputs_func = sample_inputs_index_put_nofp64
 
-
+                if opinfo.reference_inputs_func is not None and opinfo.reference_inputs_func.__name__ == common_methods_invocations.reference_inputs_cat.__name__:
+                    opinfo.reference_inputs_func = reference_inputs_cat_nofp64
 
 
     def __enter__(self):
         # Monkey patch until we have a fancy way
         common_device_type.onlyCUDA = common_device_type.onlyXPU
 
-        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            common_methods_invocations.index_variable = index_variable_nofp64
-            common_methods_invocations.reference_inputs_cat = reference_inputs_cat_nofp64
-
         class dtypesIfXPU(common_device_type.dtypes):
             def __init__(self, *args):
                 super().__init__(*args, device_type="xpu")
@@ -909,10 +920,6 @@ def __exit__(self, exc_type, exc_value, traceback):
         cuda.is_available = self.cuda_is_available
         cuda.is_bf16_supported = self.cuda_is_bf16_supported
 
-        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            common_methods_invocations.index_variable = self.index_variable
-            common_methods_invocations.reference_inputs_cat = self.reference_inputs_cat
-
 
 # Copy the test cases from generic_base_class to generic_test_class.
 # It serves to reuse test cases. Regarding some newly added hardware,
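
A note on the pattern this patch adopts: instead of rebinding module-level
names in common_methods_invocations while the test context is active,
filter_fp64_sample_input now swaps the generator stored on each matching
op_db entry. The commit message only says the old hook was "not stable"; one
plausible failure mode is a caller that holds a direct reference to the
original function object and therefore never observes a module-attribute
rebinding. The following self-contained Python sketch illustrates the
per-entry replacement; FakeOpInfo and both generators are hypothetical
stand-ins, not the real torch OpInfo machinery.

# Minimal sketch of the per-entry replacement pattern (hypothetical names).
from dataclasses import dataclass
from typing import Callable, Optional

def sample_inputs_index_put(op_info, device, dtype):
    # Stand-in for the stock generator that assumes fp64 support.
    yield "fp64-dependent samples"

def sample_inputs_index_put_nofp64(op_info, device, dtype):
    # Stand-in for the fp64-free variant defined in xpu_test_utils.py.
    yield "fp64-free samples"

@dataclass
class FakeOpInfo:
    # Minimal stand-in for torch.testing._internal.opinfo.core.OpInfo.
    name: str
    sample_inputs_func: Callable
    reference_inputs_func: Optional[Callable] = None

op_db = [FakeOpInfo("index_put", sample_inputs_index_put)]

def filter_fp64_sample_input(db):
    # Match by __name__, as the patch does, then replace the function
    # stored on the entry itself rather than rebinding a module attribute.
    for opinfo in db:
        if opinfo.sample_inputs_func.__name__ == sample_inputs_index_put.__name__:
            opinfo.sample_inputs_func = sample_inputs_index_put_nofp64

filter_fp64_sample_input(op_db)
assert op_db[0].sample_inputs_func is sample_inputs_index_put_nofp64

Matching on __name__ rather than on object identity keeps the comparison
stable even when the stock generator has been wrapped or re-exported as a
distinct function object.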