Skip to content

Commit

Permalink
scipy 1.14 update
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Jul 8, 2024
1 parent 54ec546 commit 7e6bf5f
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion supirfactor_dynamical/_io/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _get_data_from_ad(

if densify:
try:
_output = _output.A
_output = _output.toarray()
except AttributeError:
pass

Expand Down
6 changes: 3 additions & 3 deletions supirfactor_dynamical/_utils/_adata_load.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import anndata as ad
import numpy as np

import scipy.sparse as sps

from pandas.api.types import is_float_dtype

Expand Down Expand Up @@ -112,7 +112,7 @@ def load_data_files_jtb_2023(
count_data = count_scaling.fit_transform(count_data)

try:
count_data = count_data.A
count_data = count_data.toarray()
except AttributeError:
pass

Expand Down Expand Up @@ -226,7 +226,7 @@ def _shuffle_data(
return _sim_ints(
pvec / pvec.sum(),
np.full(pvec.shape, 3099, dtype=int),
sparse=hasattr(x, 'A')
sparse=sps.issparse(x)
)

elif sim_velo:
Expand Down
3 changes: 2 additions & 1 deletion supirfactor_dynamical/datasets/time_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,9 @@ def _get_data(data, idx, keep_sparse=False):

if issparse(_data) and not keep_sparse:

_data = _data.toarray()
_data = torch.Tensor(
_data.A.reshape(-1) if _data.shape[0] == 1 else _data.A
_data.reshape(-1) if _data.shape[0] == 1 else _data
)

return _data
Expand Down
2 changes: 1 addition & 1 deletion supirfactor_dynamical/tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_bulk_sampling(self):
self.assertIsNone(td.strat_idxes)

try:
x = self.adata.X.A
x = self.adata.X.toarray()
except AttributeError:
x = self.adata.X

Expand Down
2 changes: 1 addition & 1 deletion supirfactor_dynamical/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_h5ad_sparse(self):

adata2 = _read_ad(self.temp_file_name, "adata")

npt.assert_almost_equal(adata.X.A, adata2.X.A)
npt.assert_almost_equal(adata.X.toarray(), adata2.X.toarray())
pdt.assert_index_equal(adata.obs_names, adata2.obs_names)
pdt.assert_index_equal(adata.var_names, adata2.var_names)

Expand Down

0 comments on commit 7e6bf5f

Please sign in to comment.