Use precomputed clustering (#406)
* cluster only when clustering key does not exist or if forced

* add tests for precomputed clusters

* allow passing arguments to cluster_optimal_resolution from isolated label score

* fix docstring issues

* add test for isolated label score with precomputed clusters

* fix neighbors check

* fix resolution getter function

* update explicit get_resolutions functions

* return no isolated labels when the minimum number of batches per label equals the total number of batches

* update docstrings for clustering

* add use_rep to optimal clustering

* use graph connectivity for scanorama
mumichae committed Apr 22, 2024
1 parent 127e41e commit 9c5a936
Showing 10 changed files with 214 additions and 65 deletions.
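
A minimal sketch of the workflow this commit enables: cluster once, then reuse the precomputed clusters in the isolated label score. The ``adata`` object, its ``celltype`` and ``batch`` columns and the ``X_emb`` embedding are illustrative assumptions; the exact argument handling is defined in the diff below::

    import scanpy as sc
    import scib

    # a kNN graph is required before clustering (adata and X_emb are assumed to exist)
    sc.pp.neighbors(adata, use_rep="X_emb")

    # cluster once; each resolution is stored as adata.obs[f"iso_label_{res}"],
    # the optimal assignment as adata.obs["iso_label"]
    scib.cl.cluster_optimal_resolution(adata, cluster_key="iso_label", label_key="celltype")

    # reuse the precomputed clusters (no re-clustering, since force defaults to False)
    scib.me.isolated_labels_f1(adata, batch_key="batch", label_key="celltype")

    # force a fresh clustering instead of reusing existing keys
    scib.me.isolated_labels_f1(adata, batch_key="batch", label_key="celltype", force=True)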
17 changes: 17 additions & 0 deletions docs/source/api.rst
@@ -91,6 +91,22 @@ For these, you need to additionally provide the corresponding label column of ``
:skip: runTrVaep
:skip: issparse


Clustering
----------
.. currentmodule:: scib.metrics

After integration, one of the first ways to determine the quality of the integration is to cluster the integrated data and compare the clusters to the original annotations.
This is exactly what some of the metrics do.

.. autosummary::
:toctree: api/

cluster_optimal_resolution
get_resolutions
opt_louvain


Metrics
-------

@@ -184,6 +200,7 @@ Some parts of metrics can be used individually, these are listed below.
:toctree: api/

cluster_optimal_resolution
get_resolutions
lisi_graph
pcr
pc_regression
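
Tying the new Clustering section together, a hedged sketch of how the documented helpers are typically combined. An ``adata`` object with a ``celltype`` annotation and a precomputed neighbors graph is assumed, and ``"cluster"`` is just an illustrative key name::

    import scib

    # cluster the integrated data at the default resolutions and keep the best one
    scib.cl.cluster_optimal_resolution(adata, label_key="celltype", cluster_key="cluster")

    # compare the resulting clusters to the original annotation, e.g. with NMI
    scib.me.nmi(adata, "celltype", "cluster")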
106 changes: 70 additions & 36 deletions scib/metrics/clustering.py
@@ -1,5 +1,6 @@
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
@@ -8,11 +9,27 @@
from .nmi import nmi


def get_resolutions(n=20, min=0.1, max=2):
min = np.max([1, int(min * 10)])
max = np.max([min, max * 10])
frac = n / 10
return [frac * x / n for x in range(min, max + 1)]
def get_resolutions(n=20, min=0, max=2):
"""
Get equally spaced resolutions for optimised clustering
:param n: number of resolutions
:param min: minimum resolution
:param max: maximum resolution
.. code-block:: python
scib.cl.get_resolutions(n=10)
Output:
.. code-block::
[0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]
"""
res_range = max - min
return [res_range * (x + 1) / n for x in range(n)]
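
A quick numeric check of the new spacing (a sketch; the listed values follow directly from the formula above with the default ``min=0`` and ``max=2``, and the import path follows this file's location)::

    from scib.metrics.clustering import get_resolutions

    print(get_resolutions(n=4))   # [0.5, 1.0, 1.5, 2.0]
    print(get_resolutions(n=10))  # [0.2, 0.4, ..., 2.0], as in the docstring above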


def cluster_optimal_resolution(
@@ -23,7 +40,7 @@ def cluster_optimal_resolution(
metric=None,
resolutions=None,
use_rep=None,
force=True,
force=False,
verbose=True,
return_all=False,
metric_kwargs=None,
@@ -36,14 +53,15 @@
:param adata: anndata object
:param label_key: name of column in adata.obs containing biological labels to be
optimised against
:param cluster_key: name of column to be added to adata.obs during clustering.
Will be overwritten if exists and ``force=True``
:param cluster_key: name and prefix of columns to be added to adata.obs during clustering.
Each resolution will be saved under "{cluster_key}_{resolution}", while the optimal clustering will be under ``cluster_key``.
If ``force=True`` and one of the keys already exists, it will be overwritten.
:param cluster_function: a clustering function that takes an anndata.Anndata object. Default: Leiden clustering
:param metric: function that computes the cost to be optimised over. Must take as
arguments ``(adata, label_key, cluster_key, **metric_kwargs)`` and returns a number for maximising
Default is :func:`~scib.metrics.nmi()`
:param resolutions: list of resolutions to be optimised over. If ``resolutions=None``,
default resolutions of 10 values ranging between 0.1 and 2 will be used
by default 10 equally spaced resolutions ranging between 0 and 2 will be used (see :func:`~scib.metrics.get_resolutions`)
:param use_rep: key of embedding to use only if ``adata.uns['neighbors']`` is not
defined, otherwise will be ignored
:param force: whether to overwrite the cluster assignments in the ``.obs[cluster_key]``
@@ -56,22 +74,33 @@
``res_max``: resolution of maximum score;
``score_max``: maximum score;
``score_all``: ``pd.DataFrame`` containing all scores at resolutions. Can be used to plot the score profile.
If you specify an embedding that was not used for the kNN graph (i.e. ``adata.uns["neighbors"]["params"]["use_rep"]`` is not the same as ``use_rep``),
the neighbors will be recomputed in-place.
"""
if cluster_key in adata.obs.columns:
if force:
print(
f"WARNING: cluster key {cluster_key} already exists in adata.obs and will be overwritten because "
"force=True "
)
else:
raise ValueError(
f"cluster key {cluster_key} already exists in adata, please remove the key or choose a different "
"name. If you want to force overwriting the key, specify `force=True` "

def call_cluster_function(adata, res, resolution_key, cluster_function, **kwargs):
if resolution_key in adata.obs.columns:
warnings.warn(
f"Overwriting existing key {resolution_key} in adata.obs", stacklevel=2
)

# check or recompute neighbours
knn_rep = adata.uns.get("neighbors", {}).get("params", {}).get("use_rep")
if use_rep is not None and use_rep != knn_rep:
print(f"Recompute neighbors on rep {use_rep} instead of {knn_rep}")
sc.pp.neighbors(adata, use_rep=use_rep)

# call clustering function
print(f"Cluster for {resolution_key} with {cluster_function.__name__}")
cluster_function(adata, resolution=res, key_added=resolution_key, **kwargs)

if cluster_function is None:
cluster_function = sc.tl.leiden

if cluster_key is None:
cluster_key = cluster_function.__name__

if metric is None:
metric = nmi

@@ -86,30 +115,27 @@ def cluster_optimal_resolution(
clustering = None
score_all = []

if use_rep is None:
try:
adata.uns["neighbors"]
except KeyError:
raise RuntimeError(
"Neighbours must be computed when setting use_rep to None"
for res in resolutions:
resolution_key = f"{cluster_key}_{res}"

# check if clustering exists
if resolution_key not in adata.obs.columns or force:
call_cluster_function(
adata, res, resolution_key, cluster_function, **kwargs
)
else:
print(f"Compute neighbors on rep {use_rep}")
sc.pp.neighbors(adata, use_rep=use_rep)

for res in resolutions:
cluster_function(adata, resolution=res, key_added=cluster_key, **kwargs)
score = metric(adata, label_key, cluster_key, **metric_kwargs)
if verbose:
print(f"resolution: {res}, {metric.__name__}: {score}")
# score cluster resolution
score = metric(adata, label_key, resolution_key, **metric_kwargs)
score_all.append(score)

if verbose:
print(f"resolution: {res}, {metric.__name__}: {score}", flush=True)

# optimise score
if score_max < score:
score_max = score
res_max = res
clustering = adata.obs[cluster_key]
del adata.obs[cluster_key]
clustering = adata.obs[resolution_key]

if verbose:
print(f"optimised clustering against {label_key}")
@@ -120,10 +146,16 @@
zip(resolutions, score_all), columns=["resolution", "score"]
)

# save optimal clustering in adata.obs
if cluster_key in adata.obs.columns:
warnings.warn(
f"Overwriting existing key {cluster_key} in adata.obs", stacklevel=2
)
adata.obs[cluster_key] = clustering

if return_all:
return res_max, score_max, score_all
return res_max, score_max
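
To illustrate the new caching behaviour (a hedged sketch; ``adata``, ``"celltype"`` and the ``"cluster"`` key are illustrative, and the per-resolution column names follow the ``{cluster_key}_{resolution}`` convention described in the docstring above)::

    import scib

    # first call clusters at every resolution and stores each assignment,
    # e.g. adata.obs["cluster_0.2"], plus the optimal one in adata.obs["cluster"]
    res_max, score_max, score_all = scib.cl.cluster_optimal_resolution(
        adata, label_key="celltype", cluster_key="cluster", return_all=True
    )

    # a second call with the new default force=False reuses the stored assignments
    scib.cl.cluster_optimal_resolution(adata, label_key="celltype", cluster_key="cluster")

    # force=True re-runs the clustering and overwrites the existing keys
    scib.cl.cluster_optimal_resolution(adata, label_key="celltype", cluster_key="cluster", force=True)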


@deprecated
@@ -142,6 +174,8 @@ def opt_louvain(
):
"""Optimised Louvain clustering
DEPRECATED: Use :func:`~scib.metrics.cluster_optimal_resolution` instead
Louvain clustering with resolution optimised against a metric
:param adata: anndata object
60 changes: 53 additions & 7 deletions scib/metrics/isolated_labels.py
@@ -1,3 +1,5 @@
import warnings

import pandas as pd
from sklearn.metrics import f1_score, silhouette_samples

@@ -9,8 +11,11 @@ def isolated_labels_f1(
label_key,
batch_key,
embed,
cluster_key="iso_label",
resolutions=None,
iso_threshold=None,
verbose=True,
**kwargs,
):
"""Isolated label score F1
@@ -25,11 +30,17 @@ def isolated_labels_f1(
:param iso_threshold: max number of batches per label for label to be considered as
isolated, if iso_threshold is integer.
If ``iso_threshold=None``, consider minimum number of batches that labels are present in
:param cluster_key: clustering key prefix to look up or recompute for each resolution in ``resolutions``.
Is passed to :func:`~scib.metrics.cluster_optimal_resolution`
:param resolutions: list of resolutions to be passed to :func:`~scib.metrics.cluster_optimal_resolution`
:param verbose:
:params \\**kwargs: additional arguments to be passed to :func:`~scib.metrics.cluster_optimal_resolution`
:return: Mean of F1 scores over all isolated labels
This function performs clustering on a kNN graph and can be applied to all integration output types.
For this metric the ``adata`` needs a kNN graph.
For this metric the ``adata`` needs a kNN graph and can optionally make use of precomputed clustering (see example below).
The precomputed clusters must be saved under ``adata.obs[cluster_key]`` as well as ``adata.obs[f"{cluster_key}_{resolution}"]`` for all resolutions.
See :ref:`preprocessing` for more information on preprocessing.
**Examples**
@@ -49,15 +60,25 @@
# knn output
scib.me.isolated_labels_f1(adata, batch_key="batch", label_key="celltype")
# use precomputed clustering
scib.cl.cluster_optimal_resolution(adata, cluster_key="iso_label", label_key="celltype")
scib.me.isolated_labels_f1(adata, batch_key="batch", label_key="celltype")
# overwrite existing clustering
scib.me.isolated_labels_f1(adata, batch_key="batch", label_key="celltype", force=True)
"""
return isolated_labels(
adata,
label_key=label_key,
batch_key=batch_key,
embed=embed,
cluster=True,
cluster_key=cluster_key,
resolutions=resolutions,
iso_threshold=iso_threshold,
verbose=verbose,
**kwargs,
)


@@ -84,6 +105,7 @@ def isolated_labels_asw(
If ``iso_threshold=None``, consider minimum number of batches that labels are present in
:param scale: Whether to scale the score between 0 and 1. Only relevant for ASW scores.
:param verbose:
:params \\**kwargs: additional arguments to be passed to :func:`~scib.metrics.cluster_optimal_resolution`
:return: Mean of ASW over all isolated labels
The function requires an embedding to be stored in ``adata.obsm`` and can only be applied to feature and embedding
@@ -125,10 +147,13 @@ def isolated_labels(
batch_key,
embed,
cluster=True,
cluster_key="iso_label",
resolutions=None,
iso_threshold=None,
scale=True,
return_all=False,
verbose=True,
**kwargs,
):
"""Isolated label score
@@ -146,9 +171,12 @@
:param iso_threshold: max number of batches per label for label to be considered as
isolated, if iso_threshold is integer.
If iso_threshold=None, consider minimum number of batches that labels are present in
:param cluster_key: name of key to be passed to :func:`~scib.metrics.cluster_optimal_resolution`
:param resolutions: list of resolutions to be passed to :func:`~scib.metrics.cluster_optimal_resolution`
:param scale: Whether to scale the score between 0 and 1. Only relevant for ASW scores.
:param return_all: return scores for all isolated labels instead of aggregated mean
:param verbose:
:params \\**kwargs: additional arguments to be passed to :func:`~scib.metrics.cluster_optimal_resolution`
:return:
Mean of scores for each isolated label
or dictionary of scores for each label if `return_all=True`
@@ -158,6 +186,8 @@
isolated_labels = get_isolated_labels(
adata, label_key, batch_key, iso_threshold, verbose
)
if verbose:
print(f"isolated labels: {isolated_labels}")

# 2. compute isolated label score for each isolated label
scores = {}
@@ -171,9 +201,12 @@
label_key,
label,
embed,
cluster,
cluster_key=cluster_key,
cluster=cluster,
scale=scale,
verbose=verbose,
resolutions=resolutions,
**kwargs,
)
scores[label] = score
scores = pd.Series(scores)
@@ -189,10 +222,12 @@ def score_isolated_label(
label_key,
isolated_label,
embed,
cluster_key,
cluster=True,
iso_label_key="iso_label",
resolutions=None,
scale=True,
verbose=False,
**kwargs,
):
"""
Compute label score for a single label
@@ -203,10 +238,12 @@
:param embed: embedding to be passed to opt_louvain, if adata.uns['neighbors'] is missing
:param cluster: if True, compute clustering-based F1 score, otherwise compute
silhouette score on grouping of isolated label vs all other remaining labels
:param iso_label_key: name of key to use for cluster assignment for F1 score or
:param cluster_key: name of key to use for cluster assignment for F1 score or
isolated-vs-rest assignment for silhouette score
:param resolutions: list of resolutions to be passed to :func:`~scib.metrics.cluster_optimal_resolution`
:param scale: Whether to scale the score between 0 and 1. Only relevant for ASW scores.
:param verbose:
:params \\**kwargs: additional arguments to be passed to :func:`~scib.metrics.cluster_optimal_resolution`
:return:
Isolated label score
"""
@@ -233,13 +270,15 @@ def max_f1(adata, label_key, cluster_key, label, argmax=False):
cluster_optimal_resolution(
adata,
label_key,
cluster_key=iso_label_key,
cluster_key=cluster_key,
use_rep=embed,
metric=max_f1,
metric_kwargs={"label": isolated_label},
verbose=False,
resolutions=resolutions,
force=False,
verbose=verbose,
)
score = max_f1(adata, label_key, iso_label_key, isolated_label, argmax=False)
score = max_f1(adata, label_key, cluster_key, isolated_label, argmax=False)
else:
# ASW score between isolated label vs rest
if "silhouette_temp" not in adata.obs:
@@ -275,6 +314,13 @@ def get_isolated_labels(adata, label_key, batch_key, iso_threshold, verbose):
if iso_threshold is None:
iso_threshold = batch_per_lab.min().tolist()[0]

if iso_threshold == adata.obs[batch_key].nunique():
warnings.warn(
"iso_threshold is equal to number of batches in data, no isolated labels will be found",
stacklevel=2,
)
return []

if verbose:
print(f"isolated labels: no more than {iso_threshold} batches per label")
