Skip to content

Commit

Permalink
Expose cbar kwargs (#244)
Browse files Browse the repository at this point in the history
* Add cbar_kwargs to heatmap

* Add tests
  • Loading branch information
michalk8 authored Jan 26, 2021
1 parent e35c124 commit 3bbb5f3
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 14 deletions.
2 changes: 2 additions & 0 deletions squidpy/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def decorator2(obj: Any) -> Any:
The title of the plot.
cmap
Continuous colormap to use.
cbar_kwargs
Keyword arguments for :meth:`matplotlib.figure.Figure.colorbar`.
{_cat_plotting}"""
_plotting_returns = """\
Nothing, just plots the and optionally saves the plot.
Expand Down
4 changes: 4 additions & 0 deletions squidpy/pl/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def interaction_matrix(
title: Optional[str] = None,
cmap: str = "viridis",
palette: Palette_t = None,
cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[Union[str, Path]] = None,
Expand Down Expand Up @@ -189,6 +190,7 @@ def interaction_matrix(
annotate=annotate,
figsize=(2 * ad.n_obs // 3, 2 * ad.n_obs // 3) if figsize is None else figsize,
dpi=dpi,
cbar_kwargs=cbar_kwargs,
**kwargs,
)

Expand All @@ -206,6 +208,7 @@ def nhood_enrichment(
title: Optional[str] = None,
cmap: str = "viridis",
palette: Palette_t = None,
cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[Union[str, Path]] = None,
Expand Down Expand Up @@ -251,6 +254,7 @@ def nhood_enrichment(
annotate=annotate,
figsize=(2 * ad.n_obs // 3, 2 * ad.n_obs // 3) if figsize is None else figsize,
dpi=dpi,
cbar_kwargs=cbar_kwargs,
**kwargs,
)

Expand Down
30 changes: 16 additions & 14 deletions squidpy/pl/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import copy
from types import MappingProxyType
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -482,10 +483,12 @@ def _heatmap(
annotate: bool = True,
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> mpl.figure.Figure:

_assert_categorical_obs(adata, key=key)

cbar_kwargs = dict(cbar_kwargs)
fig, ax = plt.subplots(constrained_layout=True, dpi=dpi, figsize=figsize)

if method is not None:
Expand All @@ -495,15 +498,14 @@ def _heatmap(

row_order = row_order[::-1]
row_labels = adata.obs[key][row_order]

data = adata[row_order, col_order].X

row_cmap, col_cmap, row_norm, col_norm, n_cls = _get_cmap_norm(adata, key, order=(row_order, col_order))

row_sm = mpl.cm.ScalarMappable(cmap=row_cmap, norm=row_norm)
col_sm = mpl.cm.ScalarMappable(cmap=col_cmap, norm=col_norm)

minn, maxx = np.nanmin(data), np.nanmax(data)
norm = mpl.colors.Normalize(vmin=minn, vmax=maxx)
norm = mpl.colors.Normalize(vmin=kwargs.pop("vmin", np.nanmin(data)), vmax=kwargs.pop("vmax", np.nanmax(data)))
cont_cmap = copy(plt.get_cmap(cont_cmap))
cont_cmap.set_bad(color="grey")

Expand All @@ -526,26 +528,26 @@ def _heatmap(
sch.dendrogram(col_link, no_labels=True, ax=col_ax, color_threshold=0, above_threshold_color="black")
col_ax.axis("off")

_ = mpl.colorbar.ColorbarBase(
cax,
cmap=cont_cmap,
norm=norm,
ticks=np.linspace(np.nanmin(data), np.nanmax(data), 10),
_ = fig.colorbar(
im,
cax=cax,
ticks=np.linspace(norm.vmin, norm.vmax, 10),
orientation="vertical",
format="%0.2f",
**cbar_kwargs,
)

# column labels colorbar
c = fig.colorbar(col_sm, cax=col_cats, orientation="horizontal")
c.set_ticks([])
(col_cats if method is None else col_ax).set_title(title)

# row labels colorbar
c = fig.colorbar(row_sm, cax=row_cats, orientation="vertical", ticklocation="left")
c.set_ticks(np.arange(n_cls) + 0.5)
c.set_ticklabels(row_labels)
c.set_label(key)

if method is not None:
col_ax.set_title(title)
else:
col_cats.set_title(title)

return fig


Expand Down
Binary file added tests/_images/Heatmap_cbar_kwargs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Heatmap_cbar_vmin_vmax.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 14 additions & 0 deletions tests/tests_plotting/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ def test_tol_plot_co_occurrence_palette(self, adata_palette: AnnData):
self.compare("Graph_co_occurrence_palette", tolerance=70)


class TestHeatmap(PlotTester, metaclass=PlotTesterMeta):
def test_plot_cbar_vmin_vmax(self, adata: AnnData):
gr.spatial_neighbors(adata)
gr.nhood_enrichment(adata, cluster_key=C_KEY)

pl.nhood_enrichment(adata, cluster_key=C_KEY, vmin=10, vmax=20)

def test_plot_cbar_kwargs(self, adata: AnnData):
gr.spatial_neighbors(adata)
gr.nhood_enrichment(adata, cluster_key=C_KEY)

pl.nhood_enrichment(adata, cluster_key=C_KEY, cbar_kwargs={"label": "FOOBARBAZQUUX", "filled": False})


class TestLigrec(PlotTester, metaclass=PlotTesterMeta):
def test_invalid_type(self):
with pytest.raises(TypeError, match=r"Expected `adata` .+ found `int`."):
Expand Down

0 comments on commit 3bbb5f3

Please sign in to comment.