Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix set_obs throwing ValueError when Series aligns to MuData. #504

Merged
merged 4 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning][].

- Fix default value for `n_jobs` in `ir.tl.ir_query` that could lead to an error ([#498](https://github.com/scverse/scirpy/pull/498)).
- Update description of D50 diversity metric in documentation ([#499](https://github.com/scverse/scirpy/pull/499)).
- Fix `clonotype_modularity` not being able to store result in MuData in some cases ([#504](https://github.com/scverse/scirpy/pull/504)).
- Fix issue with creating sparse matrices from generators with the latest scipy version ([#504](https://github.com/scverse/scirpy/pull/504)).

## v0.16.0

Expand Down
29 changes: 27 additions & 2 deletions src/scirpy/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,20 @@ def test_data_handler_no_airr():
DataHandler(adata, "airr", "airr")


def test_data_handler_get_obs():
@pytest.fixture
def mdata_with_smaller_adata():
    """Build a MuData whose modalities each hold fewer cells than the MuData.

    The "gex" modality holds cells c1-c3 and the "airr" modality holds cells
    c3-c5, so only c3 is shared. A column named ``both`` exists in each
    modality and is also set on the MuData itself, ``airr_only`` is set on the
    "airr" modality, and ``mudata_only`` is set on the MuData ``.obs``.
    """
    gex_obs = pd.DataFrame({"both": [11, 12, 13]}, index=["c1", "c2", "c3"])
    airr_obs = pd.DataFrame({"both": [14, 15, 16]}, index=["c3", "c4", "c5"])
    mdata = MuData({"gex": AnnData(obs=gex_obs), "airr": AnnData(obs=airr_obs)})

    # Modality-level column, added after the MuData has been assembled.
    mdata["airr"].obs["airr_only"] = [3, 4, 5]

    # MuData-level columns: one entry per cell of the combined object.
    mdata.obs["mudata_only"] = [1, 2, 3, 4, 5]
    mdata.obs["both"] = [np.nan, np.nan, 114, 115, 116]
    return mdata


params = DataHandler(mdata, "airr")
def test_data_handler_get_obs(mdata_with_smaller_adata):
params = DataHandler(mdata_with_smaller_adata, "airr")
# can retrieve value from mudata
npt.assert_equal(params.get_obs("mudata_only").values, np.array([1, 2, 3, 4, 5]))
# Mudata takes precedence
Expand Down Expand Up @@ -109,6 +113,27 @@ def test_data_handler_get_obs():
)


@pytest.mark.parametrize(
    "value,exception",
    [
        # Indexed Series: none of these is expected to raise, including an
        # index label ("c8") that is absent from the fixture.
        [pd.Series([1, 2, 3], index=["c1", "c3", "c4"]), None],
        [pd.Series([1, 2, 3], index=["c1", "c3", "c8"]), None],
        [pd.Series([1, 2, 3, 4, 5], index=["c1", "c2", "c3", "c4", "c5"]), None],
        # Plain sequences: lengths 3 and 5 are accepted, 4 and 6 must raise.
        [[1, 2, 3], None],
        [[1, 2, 3, 4], ValueError],
        [[1, 2, 3, 4, 5], None],
        [[1, 2, 3, 4, 5, 6], ValueError],
    ],
)
def test_data_handler_set_obs(mdata_with_smaller_adata, value, exception):
    """Check that ``DataHandler.set_obs`` accepts or rejects ``value``.

    ``exception`` is ``None`` when the call is expected to succeed, otherwise
    the exception type the call must raise.
    """
    handler = DataHandler(mdata_with_smaller_adata, "airr")

    if exception is None:
        handler.set_obs("test", value)
        return

    with pytest.raises(exception):
        handler.set_obs("test", value)


def test_data_handler_initalize_from_object(adata_tra):
dh = DataHandler(adata_tra, "airr", "airr")
dh2 = DataHandler(dh)
Expand Down
11 changes: 10 additions & 1 deletion src/scirpy/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,20 @@ def _get_obs_col(self, column: str) -> pd.Series:
def set_obs(self, key: str, value: Union[pd.Series, Sequence[Any], np.ndarray]) -> None:
"""Store results in .obs of AnnData and MuData.

If `value` is not a Series, if the length is equal to the params.mdata, we assume it aligns to the
MuData object. Otherwise, if the length is equal to the params.adata, we assume it aligns to the
AnnData object. Otherwise, a ValueError is thrown.

The result will be written to `mdata.obs["{airr_mod}:{key}"]` and to `adata.obs[key]`.
"""
# index series with AnnData (in case MuData has different dimensions)
if not isinstance(value, pd.Series):
value = pd.Series(value, index=self.adata.obs_names)
if len(value) == self.data.shape[0]:
value = pd.Series(value, index=self.data.obs_names)
elif len(value) == self.adata.shape[0]:
value = pd.Series(value, index=self.adata.obs_names)
else:
raise ValueError("Provided values without index and can't align with either MuData or AnnData.")
if isinstance(self.data, MuData):
# write to both AnnData and MuData
if self._airr_mod is None:
Expand Down
2 changes: 1 addition & 1 deletion src/scirpy/util/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _get_sparse_from_igraph(graph, *, simplified, weight_attr=None):
shape = graph.vcount()
shape = (shape, shape)
if len(edges) > 0:
adj_mat = csr_matrix((weights, zip(*edges)), shape=shape)
adj_mat = csr_matrix((weights, list(zip(*edges))), shape=shape)
if simplified:
# make symmetrical and add diagonal
adj_mat = adj_mat + adj_mat.T - sparse.diags(adj_mat.diagonal()) + sparse.diags(np.ones(adj_mat.shape[0]))
Expand Down
Loading