Seurat v3 VST highly variable gene method #993

Closed
adamgayoso opened this issue Jan 14, 2020 · 7 comments · Fixed by #1204

Comments

@adamgayoso
Member

adamgayoso commented Jan 14, 2020

I find this method to be the most conceptually straightforward and it gives great results in my tests.

I have a rough implementation in Python. I see that making a PR would be more involved, as the existing code relies on log-transformed data, while the Seurat method should be run on the raw counts.

I also understand that adding rpy2 to scanpy could be a bit challenging, so I have a close approximation using the statsmodels library.

import numpy as np
import pandas as pd
import statsmodels.api as sm

def seurat_v3_highly_variable_genes(adata, n_top_genes=4000, use_lowess=False):
    norm_gene_vars = []
    del_batch = False
    if "batch" not in adata.obs_keys():
        del_batch = True
        adata.obs["batch"] = np.zeros((adata.X.shape[0]))
    for b in np.unique(adata.obs["batch"]):
        var = adata[adata.obs["batch"] == b].X.var(0)
        mean = adata[adata.obs["batch"] == b].X.mean(0)
        estimat_var = np.zeros((adata.X.shape[1]))

        y = np.log10(var)
        x = np.log10(mean)
        if use_lowess is True:
            lowess = sm.nonparametric.lowess
            # output is sorted by x
            v = lowess(y, x, frac=0.15)
            estimat_var[np.argsort(x)] = v[:, 1]
        else:
            estimat_var = loess(y, x)

        norm_values = (adata[adata.obs["batch"] == b].X - mean) / np.sqrt(10 ** estimat_var)
        # as in seurat paper, clip max values
        norm_values = np.clip(
            norm_values, None, np.sqrt(np.sum(adata.obs["batch"] == b))
        )
        norm_gene_var = norm_values.var(0)
        norm_gene_vars.append(norm_gene_var.reshape(1, -1))

    norm_gene_vars = np.concatenate(norm_gene_vars, axis=0)
    ranked_norm_gene_vars = np.argsort(np.argsort(norm_gene_vars, axis=1), axis=1)
    median_norm_gene_vars = np.median(norm_gene_vars, axis=0)
    median_ranked = np.median(ranked_norm_gene_vars, axis=0)

    num_batches_high_var = np.sum(
        ranked_norm_gene_vars >= (adata.X.shape[1] - n_top_genes), axis=0
    )
    df = pd.DataFrame(index=np.array(adata.var_names))
    df["highly_variable_n_batches"] = num_batches_high_var
    df["highly_variable_median_rank"] = median_ranked

    df["highly_variable_median_variance"] = median_norm_gene_vars
    df.sort_values(
        ["highly_variable_n_batches", "highly_variable_median_rank"],
        ascending=False,
        na_position="last",
        inplace=True,
    )
    df["highly_variable"] = False
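    # df is indexed by gene name, so select the first n_top_genes rows by position
    df.iloc[:n_top_genes, df.columns.get_loc("highly_variable")] = True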
    df = df.loc[adata.var_names]

    if del_batch is True:
        del adata.obs["batch"]

    adata.var["highly_variable"] = df["highly_variable"].values
    adata.var["highly_variable_n_batches"] = df["highly_variable_n_batches"].values
    adata.var["highly_variable_median_variance"] = df[
        "highly_variable_median_variance"
    ].values


def loess(y, x, span=0.3):
    from rpy2.robjects import r
    import rpy2.robjects as robjects

    a, b = robjects.FloatVector(x), robjects.FloatVector(y)
    df = robjects.DataFrame({"a": a, "b": b})
    loess_fit = r.loess("b ~ a", data=df, span=span)

    return np.array(loess_fit[loess_fit.names.index("fitted")])
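
The `estimat_var[np.argsort(x)] = v[:, 1]` line relies on lowess returning its fit sorted by `x`. A minimal NumPy-only sketch of that reordering follows. The `fake_sorted_fit` helper is a hypothetical stand-in for the smoother and is not part of statsmodels. It also covers the case where only a subset of genes is fitted, which is an assumption added here for illustration:

```python
import numpy as np


def fake_sorted_fit(y, x):
    """Hypothetical stand-in for a smoother that returns rows sorted by x."""
    order = np.argsort(x)
    return np.column_stack([x[order], y[order]])


x = np.array([3.0, 1.0, 2.0, 5.0])
y = np.array([30.0, 10.0, 20.0, 50.0])

# Scatter the sorted fit back to the original gene order.
v = fake_sorted_fit(y, x)
fitted = np.zeros_like(y)
fitted[np.argsort(x)] = v[:, 1]

# Variant: fit only a subset of genes (e.g. those with nonzero variance).
keep = np.array([True, False, True, True])
v_sub = fake_sorted_fit(y[keep], x[keep])

# Chained indexing assigns into a temporary copy, so nothing is written.
chained = np.zeros_like(y)
chained[keep][np.argsort(x[keep])] = v_sub[:, 1]

# Indexing once with integer positions writes into the target array.
direct = np.zeros_like(y)
direct[np.flatnonzero(keep)[np.argsort(x[keep])]] = v_sub[:, 1]
```

With a boolean mask, `a[mask][idx] = values` silently leaves `a` unchanged, so the single-step positional form is the safer pattern.
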
@ivirshup
Member

@LuckyMD @gokceneraslan, I think this would be nice to have. What do you think? I think it would be fine for it to work on counts, as long as that's made very explicit in the documentation.

@LuckyMD
Contributor

LuckyMD commented Apr 17, 2020

Agreed!

On another note, we currently lack a method for HVG selection that works on scaled/regressed-out data. We could rewrite regress_out so that it not only outputs the residuals but also adds the intercept back.
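
A minimal sketch of that idea, assuming ordinary least squares on dense arrays. The `regress_out_keep_intercept` helper is hypothetical and is not scanpy's `regress_out`. It only illustrates what "residuals plus intercept" would look like:

```python
import numpy as np


def regress_out_keep_intercept(X, covariates):
    """Remove covariate effects per gene, then add the fitted intercept back.

    X: (cells, genes) dense array; covariates: (cells, k) dense array.
    Hypothetical helper, not scanpy's `regress_out`.
    """
    n_cells = X.shape[0]
    design = np.column_stack([np.ones(n_cells), covariates])
    # Ordinary least squares for all genes at once.
    coef, *_ = np.linalg.lstsq(design, X, rcond=None)
    residuals = X - design @ coef
    # coef[0] holds the per-gene intercepts.
    return residuals + coef[0]


rng = np.random.default_rng(0)
depth = rng.normal(size=(200, 1))
X = 5.0 + 2.0 * depth + rng.normal(scale=0.1, size=(200, 3))
corrected = regress_out_keep_intercept(X, depth)
```

The corrected values keep each gene's baseline level instead of being centered at zero, while the linear dependence on the covariate is removed.
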

@adamgayoso
Member Author

I'm glad you all are considering adding this. I updated the implementation to work with sparse counts.

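import numpy as np
import pandas as pd
import scipy.sparse as sp_sparse
import statsmodels.api as sm


def seurat_v3_highly_variable_genes(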
    adata, n_top_genes: int = 4000, batch_key: str = "batch"
):
    """ An adapted implementation of the "vst" feature selection in Seurat v3.

        The major difference is that we use lowess instead of loess.

        For further details of the sparse arithmetic see https://www.overleaf.com/read/ckptrbgzzzpg

        :param n_top_genes: How many variable genes to return
        :param batch_key: key in adata.obs that contains batch info. If None, do not use batch info

    """

    from scanpy.preprocessing._utils import _get_mean_var
    from scanpy.preprocessing._distributed import materialize_as_ndarray

    lowess = sm.nonparametric.lowess

    if batch_key is None:
        batch_correction = False
        batch_key = "batch"
        adata.obs[batch_key] = pd.Categorical(np.zeros((adata.X.shape[0])).astype(int))
    else:
        batch_correction = True

    norm_gene_vars = []
    for b in np.unique(adata.obs[batch_key]):

        mean, var = materialize_as_ndarray(
            _get_mean_var(adata[adata.obs[batch_key] == b].X)
        )
        not_const = var > 0
        estimat_var = np.zeros((adata.X.shape[1]))

        y = np.log10(var[not_const])
        x = np.log10(mean[not_const])
        # output is sorted by x
        v = lowess(y, x, frac=0.15)
        # chained indexing would write into a copy, so index once with positions
        estimat_var[np.flatnonzero(not_const)[np.argsort(x)]] = v[:, 1]

        # get normalized variance
        reg_std = np.sqrt(10 ** estimat_var)
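        # cast to float so the clipped (non-integer) values can be stored
        batch_counts = adata[adata.obs[batch_key] == b].X.astype(np.float64)
        # clip large values as in Seurat: a z-score above sqrt(N) corresponds
        # to a count above reg_std * sqrt(N) + mean
        N = np.sum(adata.obs[batch_key] == b)
        vmax = np.sqrt(N)
        clip_val = reg_std * vmax + mean

        if sp_sparse.issparse(batch_counts):
            # only stored nonzeros can exceed clip_val, so clip those in place
            batch_counts = sp_sparse.csr_matrix(batch_counts)
            mask = batch_counts.data > clip_val[batch_counts.indices]
            batch_counts.data[mask] = clip_val[batch_counts.indices[mask]]
            squared_batch_counts_sum = np.array(batch_counts.power(2).sum(axis=0))
            batch_counts_sum = np.array(batch_counts.sum(axis=0))
        else:
            batch_counts = np.minimum(batch_counts, clip_val)
            squared_batch_counts_sum = np.square(batch_counts).sum(axis=0)
            batch_counts_sum = batch_counts.sum(axis=0)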

        norm_gene_var = (1 / ((N - 1) * np.square(reg_std))) * (
            (N * np.square(mean))
            + squared_batch_counts_sum
            - 2 * batch_counts_sum * mean
        )
        norm_gene_vars.append(norm_gene_var.reshape(1, -1))

    norm_gene_vars = np.concatenate(norm_gene_vars, axis=0)
    # argsort twice gives ranks
    ranked_norm_gene_vars = np.argsort(np.argsort(norm_gene_vars, axis=1), axis=1)
    median_norm_gene_vars = np.median(norm_gene_vars, axis=0)
    median_ranked = np.median(ranked_norm_gene_vars, axis=0)

    num_batches_high_var = np.sum(
        ranked_norm_gene_vars >= (adata.X.shape[1] - n_top_genes), axis=0
    )
    df = pd.DataFrame(index=np.array(adata.var_names))
    df["highly_variable_nbatches"] = num_batches_high_var
    df["highly_variable_median_rank"] = median_ranked

    df["highly_variable_median_variance"] = median_norm_gene_vars
    df.sort_values(
        ["highly_variable_nbatches", "highly_variable_median_rank"],
        ascending=False,
        na_position="last",
        inplace=True,
    )
    df["highly_variable"] = False
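    # df is indexed by gene name, so select the first n_top_genes rows by position
    df.iloc[:n_top_genes, df.columns.get_loc("highly_variable")] = True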
    df = df.loc[adata.var_names]

    adata.var["highly_variable"] = df["highly_variable"].values
    if batch_correction is True:
        batches = adata.obs[batch_key].cat.categories
        adata.var["highly_variable_nbatches"] = df["highly_variable_nbatches"].values
        adata.var["highly_variable_intersection"] = df[
            "highly_variable_nbatches"
        ] == len(batches)
    adata.var["highly_variable_median_rank"] = df["highly_variable_median_rank"].values
    adata.var["highly_variable_median_variance"] = df[
        "highly_variable_median_variance"
    ].values
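
The batch aggregation at the end of the function can be followed on a toy example. This is only a sketch with made-up normalized variances for two batches and four genes. It uses `np.lexsort` in place of the DataFrame sort, which is a substitution made here for brevity:

```python
import numpy as np

# Rows are batches, columns are genes (toy normalized variances).
norm_gene_vars = np.array([
    [0.5, 3.0, 1.0, 2.0],
    [4.0, 0.2, 1.5, 3.0],
])
n_genes = norm_gene_vars.shape[1]
n_top_genes = 2

# argsort twice converts values to ranks (0 = lowest variance in the batch).
ranks = np.argsort(np.argsort(norm_gene_vars, axis=1), axis=1)

# A gene counts as a top gene in a batch when its rank is among the
# n_top_genes highest ranks of that batch.
n_batches_top = np.sum(ranks >= n_genes - n_top_genes, axis=0)
median_rank = np.median(ranks, axis=0)

# Order genes by batches-in-top first, then by median rank, both descending.
# np.lexsort treats the last key as the primary key.
order = np.lexsort((-median_rank, -n_batches_top))
top_genes = order[:n_top_genes]
```

Here the gene in column 3 is a top gene in both batches, so it is ranked first even though it never has the single highest variance in either batch.
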

@ivirshup
Member

@adamgayoso could you open a PR for this?

@adamgayoso
Member Author

Yes, I'll get to it next week. There didn't seem to be a straightforward way to integrate it with the existing implementation, given that the filtering criterion is different, but I'll try my best.

@cchrysostomou

@adamgayoso, I have a question about the implementation of the Seurat v3 HVG method, and I'm not sure this is the correct thread (it's probably not). It concerns the final step, where the function reports variances_norm, or norm_gene_var. Based on the description here, https://www.overleaf.com/project/5e7e320564f7d4000175d082, norm_gene_var is computed as the variance of the transformed values under the assumption that the mean of the z-scores is 0. After clipping values to a maximum, though, I think the mean of the transformed values might no longer be 0, so simply computing var(transformed values) would not give the same value as the variances_norm equation in the sparse approach. Reading through the referenced paper (Stuart 2019), it's not clear whether they compute the variance of the z-scores after clipping, or assume the mean z-score is 0 as it was before clipping.

If this is not relevant, please feel free to ask me to delete this comment.

@adamgayoso
Member Author

@cchrysostomou Indeed, the mean will no longer be zero. I was merely reimplementing exactly what is done in Seurat, and we have tests showing that in the single-batch case we get exactly the same genes. No need to delete this comment.

I suppose you can think of it as the second moment instead of the variance.
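
A small numerical sketch of that distinction, for a single gene with simulated counts. The value of `reg_std` is an assumed stand-in for the fitted standard deviation, and the extreme count is planted so that clipping is triggered:

```python
import numpy as np

rng = np.random.default_rng(0)
N = 500
counts = rng.poisson(1.0, size=N).astype(float)
counts[0] = 400.0  # one extreme cell so that clipping is triggered

mean = counts.mean()
reg_std = 1.5  # assumed stand-in for the fitted standard deviation
vmax = np.sqrt(N)

# Clipping z-scores at vmax is the same as clipping counts at
# reg_std * vmax + mean.
clipped = np.minimum(counts, reg_std * vmax + mean)
z = (clipped - mean) / reg_std

# Expanded form used with sparse counts: it centers on the unclipped mean.
expanded = (N * mean**2 + np.sum(clipped**2) - 2 * clipped.sum() * mean) / (
    (N - 1) * reg_std**2
)

second_moment = np.sum(z**2) / (N - 1)  # moment about zero
variance = z.var(ddof=1)                # centered on the mean of clipped z
```

The expanded expression matches `second_moment`, while `variance` is smaller, because the clipped z-scores no longer average to zero.
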
