Skip to content

Commit

Permalink
Small bug fixes and docs improvements (#409)
Browse files Browse the repository at this point in the history
* Small bug fixes and docs improvements: Changed pt.pl.dl.pairplot to a static method, fixed error messages in substract and add methods of PerturbationSpace, added docs parameter description for DBSCANSpace.compute and PseudobulkSpace.compute, added docs explanations for all parameters for pt.pl.scg.reg_mean_plot and pt.pl.scg.reg_var_plot

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Lilly-May and pre-commit-ci[bot] authored Oct 22, 2023
1 parent e2ff0f8 commit 62f8b14
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 20 deletions.
6 changes: 2 additions & 4 deletions pertpy/plot/_dialogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def split_violins(

return ax

def pairplot(self, adata: AnnData, celltype_key: str, color: str, sample_id: str, mcp: str = "mcp_0") -> PairGrid:
@staticmethod
def pairplot(adata: AnnData, celltype_key: str, color: str, sample_id: str, mcp: str = "mcp_0") -> PairGrid:
"""Generate a pairplot visualization for multi-cell perturbation (MCP) data.
Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
Expand All @@ -75,10 +76,7 @@ def pairplot(self, adata: AnnData, celltype_key: str, color: str, sample_id: str
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
n_counts_key = "nCount_RNA", n_mpcs = 3)
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
#>>> dl_pl=pt.pl.dl()
#>>> dl_pl.pairplot(adata=adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
>>> pt.pl.dl.pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
#TODO: Is self parameter there on purpose -> create DialoguePlot object first?
"""
mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
mean_mcps = mean_mcps.reset_index()
Expand Down
14 changes: 14 additions & 0 deletions pertpy/plot/_scgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def reg_mean_plot(
save: Specify if the plot should be saved or not.
gene_list: list of gene names to be plotted.
show: if `True`: will show to the plot after saving it.
top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
legend: if `True`: plots a legend, defaults to `True`.
title: Set if you want the plot to display a title.
x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
fontsize: Fontsize used for text in the plot, defaults to 14.
**kwargs:
Examples:
Expand Down Expand Up @@ -171,6 +178,13 @@ def reg_var_plot(
save: Specify if the plot should be saved or not.
gene_list: list of gene names to be plotted.
show: if `True`: will show to the plot after saving it.
top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
legend: if `True`: plots a legend, defaults to `True`.
title: Set if you want the plot to display a title.
verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
fontsize: Fontsize used for text in the plot, defaults to 14.
"""
import seaborn as sns

Expand Down
16 changes: 9 additions & 7 deletions pertpy/tools/_perturbation_space/_perturbation_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def compute_control_diff( # type: ignore
)

if embedding_key is not None and embedding_key not in adata.obsm_keys():
raise ValueError(
f"Reference key {reference_key} not found in {target_col}. {reference_key} must be in obs column {target_col}."
)
raise ValueError(f"Embedding key {embedding_key} not found in obsm keys of the anndata.")

if layer_key is not None and layer_key not in adata.layers.keys():
raise ValueError(f"Layer {layer_key!r} does not exist in the anndata.")
Expand Down Expand Up @@ -124,6 +122,7 @@ def add(
perturbations: Iterable[str],
reference_key: str = "control",
ensure_consistency: bool = False,
target_col: str = "perturbations",
):
"""Add perturbations linearly. Assumes input of size n_perts x dimensionality
Expand All @@ -132,6 +131,7 @@ def add(
perturbations: Perturbations to add.
reference_key: perturbation source from which the perturbation summation starts.
ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
Examples:
Example usage with PseudobulkSpace:
Expand All @@ -145,7 +145,7 @@ def add(
for perturbation in perturbations:
if perturbation not in adata.obs_names:
raise ValueError(
f"Perturbation {reference_key} not found in adata.obs_names. {reference_key} must be in adata.obs_names."
f"Perturbation {perturbation} not found in adata.obs_names. {perturbation} must be in adata.obs_names."
)
new_pert_name += perturbation + "+"

Expand All @@ -156,7 +156,7 @@ def add(
"Run with ensure_consistency=True"
)
else:
adata = self.compute_control_diff(adata, copy=True, all_data=True)
adata = self.compute_control_diff(adata, copy=True, all_data=True, target_col=target_col)

data: dict[str, np.array] = {}

Expand Down Expand Up @@ -223,6 +223,7 @@ def subtract(
perturbations: Iterable[str],
reference_key: str = "control",
ensure_consistency: bool = False,
target_col: str = "perturbations",
):
"""Subtract perturbations linearly. Assumes input of size n_perts x dimensionality
Expand All @@ -231,6 +232,7 @@ def subtract(
perturbations: Perturbations to subtract,
reference_key: Perturbation source from which the perturbation subtraction starts
ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
Examples:
Example usage with PseudobulkSpace:
Expand All @@ -244,7 +246,7 @@ def subtract(
for perturbation in perturbations:
if perturbation not in adata.obs_names:
raise ValueError(
f"Perturbation {reference_key} not found in adata.obs_names. {reference_key} must be in adata.obs_names."
f"Perturbation {perturbation} not found in adata.obs_names. {perturbation} must be in adata.obs_names."
)
new_pert_name += perturbation + "-"

Expand All @@ -255,7 +257,7 @@ def subtract(
"Run with ensure_consistency=True"
)
else:
adata = self.compute_control_diff(adata, copy=True, all_data=True)
adata = self.compute_control_diff(adata, copy=True, all_data=True, target_col=target_col)

data: dict[str, np.array] = {}

Expand Down
14 changes: 5 additions & 9 deletions pertpy/tools/_perturbation_space/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def compute(
target_col: .obs column that stores the label of the perturbation applied to each cell.
layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
**kwargs: Are passed to decoupler's get_pseuobulk.
Examples:
>>> import pertpy as pp
Expand Down Expand Up @@ -147,7 +148,7 @@ def compute( # type: ignore
adata: AnnData,
layer_key: str = None,
embedding_key: str = None,
cluster_key: str = None,
cluster_key: str = "k-means",
copy: bool = False,
return_object: bool = False,
**kwargs,
Expand All @@ -172,9 +173,6 @@ def compute( # type: ignore
if copy:
adata = adata.copy()

if cluster_key is None:
cluster_key = "k-means"

if layer_key is not None and embedding_key is not None:
raise ValueError("Please, select just either layer or embedding for computation.")

Expand Down Expand Up @@ -210,7 +208,7 @@ def compute( # type: ignore
adata: AnnData,
layer_key: str = None,
embedding_key: str = None,
cluster_key: str = None,
cluster_key: str = "dbscan",
copy: bool = True,
return_object: bool = False,
**kwargs,
Expand All @@ -221,9 +219,10 @@ def compute( # type: ignore
adata: Anndata object of size cells x genes
layer_key: If specified and exists in the adata, the clustering is done by using it. Otherwise, clustering is done with .X
embedding_key: if specified and exists in the adata, the clustering is done with that embedding. Otherwise, clustering is done with .X
cluster_key: name of the .obs column to store the cluster labels. Defaults to 'k-means'
cluster_key: name of the .obs column to store the cluster labels. Defaults to 'dbscan'
copy: if True returns a new Anndata of same size with the new column; otherwise it updates the initial adata
return_object: if True returns the clustering object
**kwargs: Are passed to sklearn's DBSCAN.
Examples:
>>> import pertpy as pt
Expand All @@ -234,9 +233,6 @@ def compute( # type: ignore
if copy:
adata = adata.copy()

if cluster_key is None:
cluster_key = "dbscan"

if embedding_key is not None:
if embedding_key not in adata.obsm_keys():
raise ValueError(f"Embedding {embedding_key!r} does not exist in the .obsm attribute.")
Expand Down

0 comments on commit 62f8b14

Please sign in to comment.