diff --git a/sklearnex/utils/tests/test_validation.py b/sklearnex/utils/tests/test_validation.py
index 70da28dbce..c770abd495 100644
--- a/sklearnex/utils/tests/test_validation.py
+++ b/sklearnex/utils/tests/test_validation.py
@@ -1,5 +1,5 @@
 # ==============================================================================
-# Copyright 2024 Intel Corporation
+# Copyright 2024 UXL Foundation Contributors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/sklearnex/utils/validation.py b/sklearnex/utils/validation.py
index c2ba2c1dc5..4e908a31ce 100755
--- a/sklearnex/utils/validation.py
+++ b/sklearnex/utils/validation.py
@@ -46,7 +46,13 @@ def _onedal_supported_format(X, xp=None):
         # _onedal_supported_format is therefore conservative in verifying attributes and
         # does not support array_api. This will block onedal_assert_all_finite from being
         # used for array_api inputs but will allow dpnp ndarrays and dpctl tensors.
-        return X.dtype in [xp.float32, xp.float64] and hasattr(X, "flags")
+        # only check contiguous arrays to prevent unnecessary copying of data, even if
+        # non-contiguous arrays can now be converted to oneDAL tables.
+        return (
+            X.dtype in [xp.float32, xp.float64]
+            and hasattr(X, "flags")
+            and (X.flags["C_CONTIGUOUS"] or X.flags["F_CONTIGUOUS"])
+        )
 
 else:
     from daal4py.utils.validation import _assert_all_finite as _onedal_assert_all_finite
@@ -108,14 +114,37 @@ def validate_data(
         y=y,
         **kwargs,
     )
+
+    check_x = not isinstance(X, str) or X != "no_validation"
+    check_y = not (y is None or isinstance(y, str) and y == "no_validation")
+
     if ensure_all_finite:
         # run local finite check
         allow_nan = ensure_all_finite == "allow-nan"
         arg = iter(out if isinstance(out, tuple) else (out,))
-        if not isinstance(X, str) or X != "no_validation":
+        if check_x:
             assert_all_finite(next(arg), allow_nan=allow_nan, input_name="X")
-        if not (y is None or isinstance(y, str) and y == "no_validation"):
+        if check_y:
             assert_all_finite(next(arg), allow_nan=allow_nan, input_name="y")
+
+    if check_y and "dtype" in kwargs:
+        # validate_data does not do full dtype conversions, as it uses check_X_y
+        # oneDAL can make tables from [int32, float32, float64], requiring
+        # a dtype check and conversion. This will query the array_namespace and
+        # convert y as necessary. This is done after assert_all_finite, because
+        # int y arrays do not need a finite check, and this will lead to a speedup
+        # in comparison to sklearn
+        dtype = kwargs["dtype"]
+        if not isinstance(dtype, (tuple, list)):
+            dtype = (dtype,)  # wrap a lone dtype; tuple(dtype) would iterate it
+
+        outx, outy = out if check_x else (None, out)
+        if outy.dtype not in dtype:
+            yp, _ = get_namespace(outy)
+            # use asarray rather than astype because of numpy support
+            outy = yp.asarray(outy, dtype=dtype[0])
+        out = (outx, outy) if check_x else outy
+
     return out
 
 