Skip to content

Commit

Permalink
Fixture for Augur tests (#406)
Browse files (browse the repository at this point in the history)
* Added a fixture for test_augur.py that uses a subsampled sc_sim_augur dataset obtained from the dataloader

* Deleted the sc_sim.h5ad dataset from the tests folder

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Lilly-May and pre-commit-ci[bot] authored Oct 22, 2023
1 parent ef8de38 commit e2ff0f8
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 63 deletions.
Binary file removed tests/tools/sc_sim.h5ad
Binary file not shown.
120 changes: 57 additions & 63 deletions tests/tools/test_augur.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,90 +16,87 @@ class TestAugur:
ag_lrc = pt.tl.Augur("logistic_regression_classifier", Params(random_state=42))
ag_rfr = pt.tl.Augur("random_forest_regressor", Params(random_state=42))

def test_load(self):
@pytest.fixture
def adata(self):
adata = pt.dt.sc_sim_augur()
adata = sc.pp.subsample(adata, n_obs=200, copy=True, random_state=10)

return adata

def test_load(self, adata):
"""Test if load function creates anndata objects."""
sc_sim_adata = sc.read_h5ad(f"{CWD}/sc_sim.h5ad")
ag = pt.tl.Augur(estimator="random_forest_classifier")

loaded_adata = ag.load(sc_sim_adata)
loaded_df = ag.load(sc_sim_adata.to_df(), meta=sc_sim_adata.obs, cell_type_col="cell_type", label_col="label")
loaded_adata = ag.load(adata)
loaded_df = ag.load(adata.to_df(), meta=adata.obs, cell_type_col="cell_type", label_col="label")

assert loaded_adata.obs["y_"].equals(loaded_df.obs["y_"]) is True
assert sc_sim_adata.to_df().equals(loaded_adata.to_df()) is True and sc_sim_adata.to_df().equals(
loaded_df.to_df()
)
assert adata.to_df().equals(loaded_adata.to_df()) is True and adata.to_df().equals(loaded_df.to_df())

def test_random_forest_classifier(self):
def test_random_forest_classifier(self, adata):
"""Tests random forest for auc calculation."""
sc_sim_adata = sc.read_h5ad(f"{CWD}/sc_sim.h5ad")
sc_sim_adata = self.ag_rfc.load(sc_sim_adata)
sc.pp.highly_variable_genes(sc_sim_adata)
adata, results = self.ag_rfc.predict(
sc_sim_adata, n_threads=4, n_subsamples=3, random_state=42, select_variance_features=False
adata = self.ag_rfc.load(adata)
sc.pp.highly_variable_genes(adata)
h_adata, results = self.ag_rfc.predict(
adata, n_threads=4, n_subsamples=3, random_state=42, select_variance_features=False
)

assert results["CellTypeA"][2]["subsample_idx"] == 2
assert "augur_score" in adata.obs.columns
assert "augur_score" in h_adata.obs.columns
assert np.allclose(results["summary_metrics"].loc["mean_augur_score"].tolist(), [0.634920, 0.933484, 0.902494])
assert "feature_importances" in results.keys()
assert len(set(results["summary_metrics"]["CellTypeA"])) == len(results["summary_metrics"]["CellTypeA"]) - 1

def test_logistic_regression_classifier(self):
def test_logistic_regression_classifier(self, adata):
"""Tests logistic classifier for auc calculation."""
sc_sim_adata = sc.read_h5ad(f"{CWD}/sc_sim.h5ad")
sc_sim_adata = self.ag_rfc.load(sc_sim_adata)
sc.pp.highly_variable_genes(sc_sim_adata)
adata, results = self.ag_lrc.predict(
sc_sim_adata, n_threads=4, n_subsamples=3, random_state=42, select_variance_features=False
adata = self.ag_rfc.load(adata)
sc.pp.highly_variable_genes(adata)
h_adata, results = self.ag_lrc.predict(
adata, n_threads=4, n_subsamples=3, random_state=42, select_variance_features=False
)

assert "augur_score" in adata.obs.columns
assert "augur_score" in h_adata.obs.columns
assert np.allclose(results["summary_metrics"].loc["mean_augur_score"].tolist(), [0.691232, 0.955404, 0.972789])
assert "feature_importances" in results.keys()

def test_random_forest_regressor(self):
def test_random_forest_regressor(self, adata):
"""Tests random forest regressor for ccc calculation."""
sc_sim_adata = sc.read_h5ad(f"{CWD}/sc_sim.h5ad")
sc_sim_adata = self.ag_rfc.load(sc_sim_adata)
sc.pp.highly_variable_genes(sc_sim_adata)
adata = self.ag_rfc.load(adata)
sc.pp.highly_variable_genes(adata)

with pytest.raises(ValueError):
self.ag_rfr.predict(sc_sim_adata, n_threads=4, n_subsamples=3, random_state=42)
self.ag_rfr.predict(adata, n_threads=4, n_subsamples=3, random_state=42)

def test_classifier(self):
def test_classifier(self, adata):
"""Test run cross validation with classifier."""
sc_sim_adata = sc.read_h5ad(f"{CWD}/sc_sim.h5ad")
sc_sim_adata = self.ag_rfc.load(sc_sim_adata)
sc.pp.highly_variable_genes(sc_sim_adata)
adata = sc.pp.subsample(sc_sim_adata, n_obs=100, random_state=42, copy=True)
adata = self.ag_rfc.load(adata)
sc.pp.highly_variable_genes(adata)
adata_subsampled = sc.pp.subsample(adata, n_obs=100, random_state=42, copy=True)

cv = self.ag_rfc.run_cross_validation(adata, subsample_idx=1, folds=3, random_state=42, zero_division=0)
cv = self.ag_rfc.run_cross_validation(
adata_subsampled, subsample_idx=1, folds=3, random_state=42, zero_division=0
)
auc = 0.786412
assert any([isclose(cv["mean_auc"], auc, abs_tol=10**-3)])

sc_sim_adata = sc.read_h5ad(f"{CWD}/sc_sim.h5ad")
sc_sim_adata = self.ag_lrc.load(sc_sim_adata)
sc.pp.highly_variable_genes(sc_sim_adata)
cv = self.ag_lrc.run_cross_validation(sc_sim_adata, subsample_idx=1, folds=3, random_state=42, zero_division=0)
cv = self.ag_lrc.run_cross_validation(adata, subsample_idx=1, folds=3, random_state=42, zero_division=0)
auc = 0.978673
assert any([isclose(cv["mean_auc"], auc, abs_tol=10**-3)])

def test_regressor(self):
def test_regressor(self, adata):
"""Test run cross validation with regressor."""
sc_sim_adata = sc.read_h5ad(f"{CWD}/sc_sim.h5ad")
sc_sim_adata = self.ag_rfc.load(sc_sim_adata)
cv = self.ag_rfr.run_cross_validation(sc_sim_adata, subsample_idx=1, folds=3, random_state=42, zero_division=0)
adata = self.ag_rfc.load(adata)
cv = self.ag_rfr.run_cross_validation(adata, subsample_idx=1, folds=3, random_state=42, zero_division=0)
ccc = 0.168800
r2 = 0.149887
assert any([isclose(cv["mean_ccc"], ccc, abs_tol=10**-5), isclose(cv["mean_r2"], r2, abs_tol=10**-5)])

def test_subsample(self):
def test_subsample(self, adata):
"""Test default, permute and velocity subsampling process."""
sc_sim_adata = sc.read_h5ad(f"{CWD}/sc_sim.h5ad")
sc_sim_adata = self.ag_rfc.load(sc_sim_adata)
sc.pp.highly_variable_genes(sc_sim_adata)
adata = self.ag_rfc.load(adata)
sc.pp.highly_variable_genes(adata)
categorical_subsample = self.ag_rfc.draw_subsample(
adata=sc_sim_adata,
adata=adata,
augur_mode="default",
subsample_size=20,
feature_perc=0.3,
Expand All @@ -109,7 +106,7 @@ def test_subsample(self):
assert len(categorical_subsample.obs_names) == 40

non_categorical_subsample = self.ag_rfc.draw_subsample(
adata=sc_sim_adata,
adata=adata,
augur_mode="default",
subsample_size=20,
feature_perc=0.3,
Expand All @@ -119,17 +116,17 @@ def test_subsample(self):
assert len(non_categorical_subsample.obs_names) == 20

permut_subsample = self.ag_rfc.draw_subsample(
adata=sc_sim_adata,
adata=adata,
augur_mode="permute",
subsample_size=20,
feature_perc=0.3,
categorical=True,
random_state=42,
)
assert (sc_sim_adata.obs.loc[permut_subsample.obs.index, "y_"] != permut_subsample.obs["y_"]).any()
assert (adata.obs.loc[permut_subsample.obs.index, "y_"] != permut_subsample.obs["y_"]).any()

velocity_subsample = self.ag_rfc.draw_subsample(
adata=sc_sim_adata,
adata=adata,
augur_mode="velocity",
subsample_size=20,
feature_perc=0.3,
Expand All @@ -142,13 +139,12 @@ def test_multiclass(self):
"""Test multiclass evaluation."""
pass

def test_select_variance(self):
def test_select_variance(self, adata):
"""Test select variance implementation."""
sc_sim_adata = sc.read_h5ad(f"{CWD}/sc_sim.h5ad")
sc_sim_adata = self.ag_rfc.load(sc_sim_adata)
sc.pp.highly_variable_genes(sc_sim_adata)
adata = sc_sim_adata[sc_sim_adata.obs["cell_type"] == "CellTypeA"]
ad = self.ag_rfc.select_variance(adata, var_quantile=0.5, span=0.3, filter_negative_residuals=False)
adata = self.ag_rfc.load(adata)
sc.pp.highly_variable_genes(adata)
adata_cell_type = adata[adata.obs["cell_type"] == "CellTypeA"]
ad = self.ag_rfc.select_variance(adata_cell_type, var_quantile=0.5, span=0.3, filter_negative_residuals=False)

assert 3672 == len(ad.var.index[ad.var["highly_variable"]])

Expand All @@ -170,17 +166,15 @@ def test_params(self):
self.ag_rfr.create_estimator("random_forest_regressor", Params(unvalid=10))

@pytest.mark.skip("Computationally expensive")
def test_differential_prioritization(self):
def test_differential_prioritization(self, adata):
"""Test differential prioritization run."""
sc_sim_adata = sc.read_h5ad(f"{CWD}/sc_sim.h5ad")

ag = pt.tl.Augur("random_forest_classifier", Params(random_state=42))
ag.load(sc_sim_adata)
ag.load(adata)

adata, results1 = ag.predict(sc_sim_adata, n_threads=4, n_subsamples=3, random_state=2)
adata, results2 = ag.predict(sc_sim_adata, n_threads=4, n_subsamples=3, random_state=42)
adata, results1 = ag.predict(adata, n_threads=4, n_subsamples=3, random_state=2)
adata, results2 = ag.predict(adata, n_threads=4, n_subsamples=3, random_state=42)

a, permut1 = ag.predict(sc_sim_adata, augur_mode="permute", n_threads=4, n_subsamples=100, random_state=2)
a, permut2 = ag.predict(sc_sim_adata, augur_mode="permute", n_threads=4, n_subsamples=100, random_state=42)
a, permut1 = ag.predict(adata, augur_mode="permute", n_threads=4, n_subsamples=100, random_state=2)
a, permut2 = ag.predict(adata, augur_mode="permute", n_threads=4, n_subsamples=100, random_state=42)
delta = ag.predict_differential_prioritization(results1, results2, permut1, permut2)
assert not np.isnan(delta["z"]).any()

0 comments on commit e2ff0f8

Please sign in to comment.