diff --git a/supirfactor_dynamical/_io/load_data.py b/supirfactor_dynamical/_io/load_data.py index ee11ae9..ee0c79a 100644 --- a/supirfactor_dynamical/_io/load_data.py +++ b/supirfactor_dynamical/_io/load_data.py @@ -156,7 +156,7 @@ def _get_data_from_ad( if densify: try: - _output = _output.A + _output = _output.toarray() except AttributeError: pass diff --git a/supirfactor_dynamical/_utils/_adata_load.py b/supirfactor_dynamical/_utils/_adata_load.py index 6ed77e0..cbc61f0 100644 --- a/supirfactor_dynamical/_utils/_adata_load.py +++ b/supirfactor_dynamical/_utils/_adata_load.py @@ -1,6 +1,6 @@ import anndata as ad import numpy as np - +import scipy.sparse as sps from pandas.api.types import is_float_dtype @@ -112,7 +112,7 @@ def load_data_files_jtb_2023( count_data = count_scaling.fit_transform(count_data) try: - count_data = count_data.A + count_data = count_data.toarray() except AttributeError: pass @@ -226,7 +226,7 @@ def _shuffle_data( return _sim_ints( pvec / pvec.sum(), np.full(pvec.shape, 3099, dtype=int), - sparse=hasattr(x, 'A') + sparse=sps.issparse(x) ) elif sim_velo: diff --git a/supirfactor_dynamical/datasets/time_dataset.py b/supirfactor_dynamical/datasets/time_dataset.py index 349ddd2..f18c8ef 100644 --- a/supirfactor_dynamical/datasets/time_dataset.py +++ b/supirfactor_dynamical/datasets/time_dataset.py @@ -380,8 +380,9 @@ def _get_data(data, idx, keep_sparse=False): if issparse(_data) and not keep_sparse: + _data = _data.toarray() _data = torch.Tensor( - _data.A.reshape(-1) if _data.shape[0] == 1 else _data.A + _data.reshape(-1) if _data.shape[0] == 1 else _data ) return _data diff --git a/supirfactor_dynamical/tests/test_loader.py b/supirfactor_dynamical/tests/test_loader.py index 7f6f1e4..f50351f 100644 --- a/supirfactor_dynamical/tests/test_loader.py +++ b/supirfactor_dynamical/tests/test_loader.py @@ -96,7 +96,7 @@ def test_bulk_sampling(self): self.assertIsNone(td.strat_idxes) try: - x = self.adata.X.A + x = self.adata.X.toarray() except AttributeError: x = self.adata.X diff --git a/supirfactor_dynamical/tests/test_serialize.py b/supirfactor_dynamical/tests/test_serialize.py index 094ee0d..f4c8271 100644 --- a/supirfactor_dynamical/tests/test_serialize.py +++ b/supirfactor_dynamical/tests/test_serialize.py @@ -172,7 +172,7 @@ def test_h5ad_sparse(self): adata2 = _read_ad(self.temp_file_name, "adata") - npt.assert_almost_equal(adata.X.A, adata2.X.A) + npt.assert_almost_equal(adata.X.toarray(), adata2.X.toarray()) pdt.assert_index_equal(adata.obs_names, adata2.obs_names) pdt.assert_index_equal(adata.var_names, adata2.var_names)