Skip to content

Commit

Permalink
Update ligrec dendrogram (#236)
Browse files Browse the repository at this point in the history
* Update ligrec dendrogram

* Add dendrogram options

* Add tests

* Fix tests
  • Loading branch information
michalk8 authored Jan 22, 2021
1 parent 263a47e commit e35c124
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 57 deletions.
7 changes: 7 additions & 0 deletions squidpy/_constants/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,10 @@ class Centrality(ModeEnum): # noqa: D101
DEGREE = "degree_centrality"
CLUSTERING = "average_clustering"
CLOSENESS = "closeness_centrality"


@unique
class DendrogramAxis(ModeEnum): # noqa: D101
INTERACTING_MOLS = "interacting_molecules"
INTERACTING_CLUSTERS = "interacting_clusters"
BOTH = "both"
39 changes: 25 additions & 14 deletions squidpy/gr/_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@

from abc import ABC
from types import MappingProxyType
from typing import Any, List, Tuple, Union, Mapping, Optional, Sequence, TYPE_CHECKING
from typing import (
Any,
List,
Tuple,
Union,
Mapping,
Iterable,
Optional,
Sequence,
TYPE_CHECKING,
)
from functools import partial
from itertools import product
from collections import namedtuple
Expand Down Expand Up @@ -233,7 +243,18 @@ def prepare(
"""
complex_policy = ComplexPolicy(complex_policy)

if isinstance(interactions, Sequence):
if isinstance(interactions, Mapping):
interactions = pd.DataFrame(interactions)

if isinstance(interactions, pd.DataFrame):
if SOURCE not in interactions.columns:
raise KeyError(f"Column `{SOURCE!r}` is not in `interactions`.")
if TARGET not in interactions.columns:
raise KeyError(f"Column `{TARGET!r}` is not in `interactions`.")

self._interactions = interactions.copy()
elif isinstance(interactions, Iterable):
interactions = tuple(interactions)
if not len(interactions):
raise ValueError("No interactions were specified.")

Expand All @@ -245,20 +266,10 @@ def prepare(
if not all(len(i) == 2 for i in interactions):
raise ValueError("Not all interactions are of length `2`.")

interactions = pd.DataFrame(interactions, columns=[SOURCE, TARGET])
elif isinstance(interactions, Mapping):
interactions = pd.DataFrame(interactions)

if isinstance(interactions, pd.DataFrame):
if SOURCE not in interactions.columns:
raise KeyError(f"Column `{SOURCE!r}` is not in `interactions`.")
if TARGET not in interactions.columns:
raise KeyError(f"Column `{TARGET!r}` is not in `interactions`.")
self._interactions = interactions.copy()
self._interactions = pd.DataFrame(interactions, columns=[SOURCE, TARGET])
else:
raise TypeError(
f"Expected either a `pandas.DataFrame`, `dict`, `tuple`, `list` or `str`, "
f"found `{type(interactions).__name__}`"
f"Expected either a `pandas.DataFrame`, `dict` or `iterable`, found `{type(interactions).__name__}`"
)
if TYPE_CHECKING:
assert isinstance(self.interactions, pd.DataFrame)
Expand Down
110 changes: 81 additions & 29 deletions squidpy/pl/_ligrec.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, Tuple, Union, Optional, Sequence
from typing import Any, Tuple, Union, Mapping, Optional, Sequence, TYPE_CHECKING
from pathlib import Path
import inspect

from scanpy import logging as logg
from anndata import AnnData
import scanpy as sc

from scipy.cluster import hierarchy as sch
import numpy as np
import pandas as pd

Expand All @@ -15,8 +15,9 @@

from squidpy._docs import d
from squidpy._utils import verbosity, _unique_order_preserving
from squidpy.pl._utils import save_fig
from squidpy.pl._utils import save_fig, _dendrogram, _filter_kwargs
from squidpy.gr._ligrec import LigrecResult
from squidpy._constants._constants import DendrogramAxis
from squidpy._constants._pkg_constants import Key

_SEP = " | "
Expand Down Expand Up @@ -102,7 +103,7 @@ def ligrec(
means_range: Tuple[float, float] = (-np.inf, np.inf),
pvalue_threshold: float = 1.0,
remove_empty_interactions: bool = True,
dendrogram: bool = False,
dendrogram: Optional[str] = None,
alpha: Optional[float] = 0.001,
swap_axes: bool = False,
title: Optional[str] = None,
Expand Down Expand Up @@ -132,7 +133,13 @@ def ligrec(
pvalue_threshold
Only show interactions with p-value <= ``pvalue_threshold``.
dendrogram
Whether to show dendrogram.
How to cluster based on the p-values. Valid options are:
- `None` - do not perform clustering.
- `'interacting_molecules'` - cluster the interacting molecules.
- `'interacting_clusters'` - cluster the interacting clusters.
- `'both'` - cluster both rows and columns. Note that in this case, the dendrogram is not shown.
swap_axes
Whether to show the cluster combinations as rows and the interacting pairs as columns.
title
Expand All @@ -147,6 +154,32 @@ def ligrec(
-------
%(plotting_returns)s
"""

def get_dendrogram(adata: AnnData, linkage: str = "complete") -> Mapping[str, Any]:
z_var = sch.linkage(
adata.X,
metric="correlation",
method=linkage,
optimal_ordering=adata.n_obs <= 1500, # matplotlib will most likely give up first
)
dendro_info = sch.dendrogram(z_var, labels=adata.obs_names.values, no_plot=True)
# this is what the DotPlot requires
return {
"linkage": z_var,
"groupby": ["groups"],
"cor_method": "pearson",
"use_rep": None,
"linkage_method": linkage,
"categories_ordered": dendro_info["ivl"],
"categories_idx_ordered": dendro_info["leaves"],
"dendrogram_info": dendro_info,
}

if dendrogram is not None:
dendrogram = DendrogramAxis(dendrogram) # type: ignore[assignment]
if TYPE_CHECKING:
assert isinstance(dendrogram, DendrogramAxis)

if isinstance(adata, AnnData):
if cluster_key is None:
raise ValueError("Please provide `cluster_key` when supplying an `AnnData` object.")
Expand Down Expand Up @@ -187,7 +220,7 @@ def ligrec(
means: pd.DataFrame = adata.means.loc[:, (source_groups, target_groups)]

if pvals.empty:
raise ValueError("No clusters have been selected.")
raise ValueError("No valid clusters have been selected.")

means = means[(means >= means_range[0]) & (means <= means_range[1])]
pvals = pvals[pvals <= pvalue_threshold]
Expand All @@ -209,6 +242,12 @@ def ligrec(
raise ValueError("After removing columns with only NaN interactions, none remain.")

start, label_ranges = 0, {}

if dendrogram == DendrogramAxis.INTERACTING_CLUSTERS:
# rows are now cluster combinations, not interacting pairs
pvals = pvals.T
means = means.T

for cls, size in (pvals.groupby(level=0, axis=1)).size().to_dict().items():
label_ranges[cls] = (start, start + size - 1)
start += size
Expand All @@ -227,26 +266,34 @@ def ligrec(
var = pd.DataFrame(pvals.columns)
var.set_index((var.columns[0]), inplace=True)

adata = AnnData(pvals.values, obs={"groups": pd.Categorical(pvals.index, pvals.index)}, var=var)
adata = AnnData(pvals.values, obs={"groups": pd.Categorical(pvals.index)}, var=var)
adata.obs_names = pvals.index
minn = np.nanmin(adata.X)
delta = np.nanmax(adata.X) - minn
adata.X = (adata.X - minn) / delta

if dendrogram:
sc.pp.pca(adata)
sc.tl.dendrogram(adata, groupby="groups", key_added="dendrogram")
try:
if dendrogram == DendrogramAxis.BOTH:
row_order, col_order, _, _ = _dendrogram(
adata.X, method="complete", metric="correlation", optimal_ordering=adata.n_obs <= 1500
)
adata = adata[row_order, :][:, col_order]
pvals = pvals.iloc[row_order, :].iloc[:, col_order]
means = means.iloc[row_order, :].iloc[:, col_order]
elif dendrogram is not None:
adata.uns["dendrogram"] = get_dendrogram(adata)
except IndexError:
# just in case pandas indexing fails
raise
except Exception as e:
logg.warning(f"Unable to create a dendrogram. Reason: `{e}`")
dendrogram = None

kwargs["dot_edge_lw"] = 0
kwargs.setdefault("cmap", "viridis")
kwargs.setdefault("grid", True)
kwargs.pop("color_on", None) # interferes with tori

style_args = {k for k in inspect.signature(sc.pl.DotPlot.style).parameters.keys()} # noqa: C416
style_dict = {k: v for k, v in kwargs.items() if k in style_args}

legend_args = {k for k in inspect.signature(sc.pl.DotPlot.legend).parameters.keys()} # noqa: C416
legend_dict = {k: v for k, v in kwargs.items() if k in legend_args}

dp = (
CustomDotplot(
delta=delta,
Expand All @@ -257,21 +304,21 @@ def ligrec(
dot_color_df=means,
dot_size_df=pvals,
title=title,
var_group_labels=list(label_ranges.keys()),
var_group_positions=list(label_ranges.values()),
var_group_labels=None if dendrogram == DendrogramAxis.BOTH else list(label_ranges.keys()),
var_group_positions=None if dendrogram == DendrogramAxis.BOTH else list(label_ranges.values()),
standard_scale=None,
figsize=figsize,
)
.style(
**style_dict,
**_filter_kwargs(sc.pl.DotPlot.style, kwargs),
)
.legend(
size_title=r"$-\log_{10} ~ P$",
colorbar_title=r"$log_2(\frac{molecule_1 + molecule_2}{2} + 1)$",
**legend_dict,
**_filter_kwargs(sc.pl.DotPlot.legend, kwargs),
)
)
if dendrogram:
if dendrogram in (DendrogramAxis.INTERACTING_MOLS, DendrogramAxis.INTERACTING_CLUSTERS):
# ignore the warning about mismatching groups
with verbosity(0):
dp.add_dendrogram(size=1.6, dendrogram_key="dendrogram")
Expand All @@ -280,19 +327,24 @@ def ligrec(

dp.make_figure()

labs = dp.ax_dict["mainplot_ax"].get_yticklabels() if swap_axes else dp.ax_dict["mainplot_ax"].get_xticklabels()
for text in labs:
text.set_text(text.get_text().split(_SEP)[1])
if swap_axes:
dp.ax_dict["mainplot_ax"].set_yticklabels(labs)
else:
dp.ax_dict["mainplot_ax"].set_xticklabels(labs)
if dendrogram != DendrogramAxis.BOTH:
# remove the target part in: source | target
labs = dp.ax_dict["mainplot_ax"].get_yticklabels() if swap_axes else dp.ax_dict["mainplot_ax"].get_xticklabels()
for text in labs:
text.set_text(text.get_text().split(_SEP)[1])
if swap_axes:
dp.ax_dict["mainplot_ax"].set_yticklabels(labs)
else:
dp.ax_dict["mainplot_ax"].set_xticklabels(labs)

if alpha is not None:
yy, xx = np.where(pvals.values >= -np.log10(alpha))
if len(xx) and len(yy):
# for dendrogram='both', they are already re-ordered
mapper = (
np.argsort(adata.uns["dendrogram"]["categories_idx_ordered"]) if dendrogram else np.arange(len(pvals))
np.argsort(adata.uns["dendrogram"]["categories_idx_ordered"])
if "dendrogram" in adata.uns
else np.arange(len(pvals))
)
logg.info(f"Found `{len(yy)}` significant interactions at level `{alpha}`")
ss = 0.33 * (adata.X[yy, xx] * (dp.largest_dot - dp.smallest_dot) + dp.smallest_dot)
Expand Down
37 changes: 28 additions & 9 deletions squidpy/pl/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from copy import copy
from typing import Any, List, Tuple, Union, Callable, Optional, Sequence, TYPE_CHECKING
from typing import (
Any,
Dict,
List,
Tuple,
Union,
Mapping,
Callable,
Optional,
Sequence,
TYPE_CHECKING,
)
from inspect import signature
from pathlib import Path
from functools import wraps
from mpl_toolkits.axes_grid1 import make_axes_locatable
Expand All @@ -10,7 +22,7 @@

from numba import njit, prange
from scipy.sparse import issparse, spmatrix
from scipy.cluster import hierarchy
from scipy.cluster import hierarchy as sch
from pandas._libs.lib import infer_dtype
from pandas.core.dtypes.common import (
is_bool_dtype,
Expand Down Expand Up @@ -477,7 +489,7 @@ def _heatmap(
fig, ax = plt.subplots(constrained_layout=True, dpi=dpi, figsize=figsize)

if method is not None:
row_order, col_order, row_link, col_link = _dendrogram(adata.X, method)
row_order, col_order, row_link, col_link = _dendrogram(adata.X, method, optimal_ordering=adata.n_obs <= 1500)
else:
row_order = col_order = np.arange(len(adata.uns[Key.uns.colors(key)]))

Expand Down Expand Up @@ -511,7 +523,7 @@ def _heatmap(
cax = divider.append_axes("right", size="1%", pad=0.1)
if method is not None: # cluster rows but don't plot dendrogram
col_ax = divider.append_axes("top", size="5%")
hierarchy.dendrogram(col_link, no_labels=True, ax=col_ax, color_threshold=0, above_threshold_color="black")
sch.dendrogram(col_link, no_labels=True, ax=col_ax, color_threshold=0, above_threshold_color="black")
col_ax.axis("off")

_ = mpl.colorbar.ColorbarBase(
Expand All @@ -537,16 +549,23 @@ def _heatmap(
return fig


def _dendrogram(data: np.array, method: str) -> Tuple[List[int], List[int], List[int], List[int]]:
def _filter_kwargs(func: Callable[..., Any], kwargs: Mapping[str, Any]) -> Dict[str, Any]:
style_args = {k for k in signature(func).parameters.keys()} # noqa: C416
return {k: v for k, v in kwargs.items() if k in style_args}


def _dendrogram(data: np.array, method: str, **kwargs: Any) -> Tuple[List[int], List[int], List[int], List[int]]:
link_kwargs = _filter_kwargs(sch.linkage, kwargs)
dendro_kwargs = _filter_kwargs(sch.dendrogram, kwargs)

# Row-cluster
row_link = hierarchy.linkage(data, method=method)
row_dendro = hierarchy.dendrogram(row_link, no_plot=True)
row_link = sch.linkage(data, method=method, **link_kwargs)
row_dendro = sch.dendrogram(row_link, no_plot=True, **dendro_kwargs)
row_order = row_dendro["leaves"]

# Column-cluster
col_link = hierarchy.linkage(data.T, method=method)
col_dendro = hierarchy.dendrogram(col_link, no_plot=True)
col_link = sch.linkage(data.T, method=method, **link_kwargs)
col_dendro = sch.dendrogram(col_link, no_plot=True, **dendro_kwargs)
col_order = col_dendro["leaves"]

return row_order, col_order, row_link, col_link
Binary file removed tests/_images/Ligrec_dendrogram.png
Binary file not shown.
Binary file added tests/_images/Ligrec_dendrogram_both.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Ligrec_dendrogram_clusters.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Ligrec_dendrogram_pairs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Ligrec_swap_axes_dedrogram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 14 additions & 4 deletions tests/tests_plotting/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_invalid_means_range_size(self, ligrec_result: LigrecResult):
pl.ligrec(ligrec_result, means_range=[0, 1, 2])

def test_invalid_clusters(self, ligrec_result: LigrecResult):
with pytest.raises(ValueError, match=r"No clusters have been selected."):
with pytest.raises(ValueError, match=r"No valid clusters have been selected."):
pl.ligrec(ligrec_result, source_groups="foo", target_groups="bar")

def test_all_interactions_empty(self, ligrec_result: LigrecResult):
Expand Down Expand Up @@ -143,14 +143,24 @@ def test_plot_pvalue_threshold(self, ligrec_result: LigrecResult):
def test_plot_means_range(self, ligrec_result: LigrecResult):
pl.ligrec(ligrec_result, means_range=(0.5, 1))

def test_plot_dendrogram(self, ligrec_result: LigrecResult):
pl.ligrec(ligrec_result, dendrogram=True)
def test_plot_dendrogram_pairs(self, ligrec_result: LigrecResult):
np.random.seed(42)
pl.ligrec(ligrec_result, dendrogram="interacting_molecules")

def test_plot_dendrogram_clusters(self, ligrec_result: LigrecResult):
# this currently "fails" (i.e. no dendrogram)
np.random.seed(42)
pl.ligrec(ligrec_result, dendrogram="interacting_clusters")

def test_plot_dendrogram_both(self, ligrec_result: LigrecResult):
np.random.seed(42)
pl.ligrec(ligrec_result, dendrogram="both")

def test_plot_swap_axes(self, ligrec_result: LigrecResult):
pl.ligrec(ligrec_result, swap_axes=True)

def test_plot_swap_axes_dedrogram(self, ligrec_result: LigrecResult):
pl.ligrec(ligrec_result, swap_axes=True, dendrogram=True)
pl.ligrec(ligrec_result, swap_axes=True, dendrogram="interacting_molecules")

def test_plot_alpha(self, ligrec_result: LigrecResult):
pl.ligrec(ligrec_result, alpha=1)
Expand Down
Loading

0 comments on commit e35c124

Please sign in to comment.