Skip to content

Commit

Permalink
Move gen_sparse_dataset function from basic statistic test to sklearn…
Browse files Browse the repository at this point in the history
…ex/tests/utils/base.py
  • Loading branch information
Vika-F committed Jan 17, 2025
1 parent efb3aa1 commit 09ecb3a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 11 deletions.
16 changes: 5 additions & 11 deletions sklearnex/basic_statistics/tests/test_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,10 @@
)
from sklearnex import config_context
from sklearnex.basic_statistics import BasicStatistics
from sklearnex.tests.utils import gen_sparse_dataset


# Generate random sparse data using scipy.sparse.random or scipy.sparse.random_array
def gen_sparse_data(row_count, column_count, **kwargs):
    """Build a random sparse dataset of shape (row_count, column_count).

    Uses ``sp.random_array`` when the ``sp`` module provides it and falls
    back to ``sp.random`` otherwise. All keyword arguments are forwarded
    unchanged to whichever generator is selected.
    """
    random_array = getattr(sp, "random_array", None)
    if random_array is None:
        # Fallback API takes the two dimensions as separate positionals.
        return sp.random(row_count, column_count, **kwargs)
    shape = (row_count, column_count)
    return random_array(shape, **kwargs)


# Compute the basic statistics on sparse data on CPU or GPU depending on the queue
def compute_sparse_result(X_sparse, options, queue):
if queue is not None and queue.sycl_device.is_gpu:
with config_context(target_offload="gpu"):
Expand Down Expand Up @@ -168,7 +162,7 @@ def test_single_option_on_random_sparse_data(

gen = np.random.default_rng(seed)

X_sparse = gen_sparse_data(
X_sparse = gen_sparse_dataset(
row_count,
column_count,
density=0.01,
Expand Down Expand Up @@ -243,7 +237,7 @@ def test_multiple_options_on_random_sparse_data(queue, row_count, column_count,

gen = np.random.default_rng(seed)

X_sparse = gen_sparse_data(
X_sparse = gen_sparse_dataset(
row_count,
column_count,
density=0.05,
Expand Down Expand Up @@ -324,7 +318,7 @@ def test_all_option_on_random_sparse_data(queue, row_count, column_count, dtype)

gen = np.random.default_rng(seed)

X_sparse = gen_sparse_data(
X_sparse = gen_sparse_dataset(
row_count,
column_count,
density=0.05,
Expand Down
2 changes: 2 additions & 0 deletions sklearnex/tests/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_get_processor_info,
call_method,
gen_dataset,
gen_sparse_dataset,
gen_models_info,
sklearn_clone_dict,
)
Expand All @@ -39,6 +40,7 @@
"call_method",
"gen_models_info",
"gen_dataset",
"gen_sparse_dataset",
"sklearn_clone_dict",
"DummyEstimator",
]
Expand Down
21 changes: 21 additions & 0 deletions sklearnex/tests/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,27 @@ def gen_dataset(
return output


def gen_sparse_dataset(row_count, column_count, **kwargs):
    """Create a random sparse dataset for pytest testing.

    Parameters
    ----------
    row_count : number of rows in the generated dataset
    column_count : number of columns in the generated dataset
    kwargs : keyword arguments forwarded unchanged to
        scipy.sparse.random_array or scipy.sparse.random

    Returns
    -------
    scipy.sparse random array when scipy.sparse.random_array is available,
    otherwise a random matrix produced by scipy.sparse.random
    """
    # Guard clause: without random_array, use the older two-argument API.
    if not hasattr(sp, "random_array"):
        return sp.random(row_count, column_count, **kwargs)

    shape = (row_count, column_count)
    return sp.random_array(shape, **kwargs)


DTYPES = [
np.int8,
np.int16,
Expand Down

0 comments on commit 09ecb3a

Please sign in to comment.