start adding JTK
Bribak committed Dec 4, 2023
1 parent 19bf1bd commit 436d1ed
Showing 7 changed files with 607 additions and 14 deletions.
47 changes: 47 additions & 0 deletions build/lib/glycowork/glycan_data/loader.py
@@ -1,3 +1,4 @@
import numpy as np
import pandas as pd
import re
import os
@@ -188,6 +189,52 @@ def multireplace(string, remove_dic):
return string


def fast_two_sum(a, b):
"""Assume abs(a) >= abs(b)"""
x = int(a) + int(b)
y = b - (x - int(a))
return [x] if y == 0 else [x, y]


def two_sum(a, b):
"""For unknown order of a and b"""
x = int(a) + int(b)
y = (a - (x - int(b))) + (b - (x - int(a)))
return [x] if y == 0 else [x, y]


def expansion_sum(*args):
"""For the expansion sum of floating points"""
g = sorted(args, reverse = True)
q, *h = fast_two_sum(np.array(g[0]), np.array(g[1]))
for val in g[2:]:
z = two_sum(q, np.array(val))
q, *extra = z
if extra:
h += extra
return [h, q] if h else q
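
Not part of the commit, but a quick illustration of what these helpers buy you: a sum in which plain floating point addition silently drops a term, while the expansion sum carries the residue along and recovers it exactly.

from glycowork.glycan_data.loader import expansion_sum

print(1e16 + 1.0 - 1e16)                # 0.0 - the 1.0 is lost, because 1e16 + 1.0 rounds back to 1e16
print(expansion_sum(1e16, 1.0, -1e16))  # 1 - the residue is kept, so the exact sum survives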


def hlm(z):
"""Hodges-Lehmann estimator of the median"""
z = np.array(z)
zz = np.add.outer(z, z)
zz = zz[np.tril_indices(len(z))]
return np.median(zz) / 2
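
A small usage note (not part of the commit): hlm returns the median of all pairwise Walsh averages, which makes it far less sensitive to outliers than the mean.

from glycowork.glycan_data.loader import hlm

hlm([1, 2, 3, 100])  # 2.75, the median of the 10 pairwise averages (i <= j)
# for comparison: np.mean([1, 2, 3, 100]) == 26.5 and np.median([1, 2, 3, 100]) == 2.5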


def update_cf_for_m_n(m, n, MM, cf):
"""Constructs cumulative frequency table for experimental parameters defined in the function 'jtkinit'"""
P = min(m + n, MM)
for t in range(n + 1, P + 1): # Zero-based offset t
for u in range(MM, t - 1, -1): # One-based descending index u
cf[u] = expansion_sum(cf[u], -cf[u - t]) # Shewchuk algorithm
Q = min(m, MM)
for s in range(1, Q + 1): # Zero-based offset s
for u in range(s, MM + 1): # One-based descending index u
cf[u] = expansion_sum(cf[u], cf[u - s]) # Shewchuk algorithm


def build_custom_df(df, kind = 'df_species'):
"""creates custom df from df_glycan\n
| Arguments:
35 changes: 34 additions & 1 deletion build/lib/glycowork/motif/analysis.py
@@ -16,7 +16,7 @@
from sklearn.decomposition import PCA

from glycowork.glycan_data.loader import lib, df_species, unwrap, motif_list
from glycowork.motif.processing import cohen_d, mahalanobis_distance, mahalanobis_variance, variance_stabilization, impute_and_normalize, variance_based_filtering
from glycowork.motif.processing import cohen_d, mahalanobis_distance, mahalanobis_variance, variance_stabilization, impute_and_normalize, variance_based_filtering, jtkdist, jtkinit, MissForest, jtkx
from glycowork.motif.annotate import annotate_dataset, quantify_motifs, link_find, create_correlation_network
from glycowork.motif.graph import subgraph_isomorphism

@@ -849,3 +849,36 @@ def get_time_series(df, impute = True, motifs = False, feature_set = ['known', '
res = pd.DataFrame(res, columns = ['Glycan', 'Change', 'p-val'])
res['corr p-val'] = multipletests(res['p-val'], method = 'fdr_bh')[1]
return res.sort_values(by = 'corr p-val')


def get_jtk(df, timepoints, replicates, periods, interval, motifs = False, feature_set = ['known', 'exhaustive', 'terminal']):
"""Wrapper function running the analysis \n
| Arguments:
| :-
| df (pd.DataFrame): A dataframe containing data for analysis.
| (column 0 = molecule IDs, then arranged in groups and by ascending timepoints)
| timepoints (int): number of timepoints in the experiment.
| replicates (int): number of replicates per timepoints.
| periods (int): number of timepoints per cycle.
| interval (int): units of time (Arbitrary units) between experimental timepoints.
| motifs (bool): a flag for running structural of motif-based analysis (True = run motif analysis); default:False.
| feature_set (list): which feature set to use for annotations, add more to list to expand; default is ['exhaustive','known']; options are: 'known' (hand-crafted glycan features), 'graph' (structural graph features of glycans), 'exhaustive' (all mono- and disaccharide features), 'terminal' (non-reducing end motifs), and 'chemical' (molecular properties of glycan)\n
| Returns:
| :-
| Returns a pandas dataframe containing the adjusted p-values, and most important waveform parameters for each
| molecule in the analysis.
"""
param_dic = {"GRP_SIZE": [], "NUM_GRPS": [], "MAX": [], "DIMS": [], "EXACT": bool(True),
"VAR": [], "EXV": [], "SDV": [], "CGOOSV": []}
param_dic = jtkdist(timepoints, param_dic, replicates)
param_dic = jtkinit(periods, param_dic, interval, replicates)
mf = MissForest()
df.replace(0, np.nan, inplace = True)
df = mf.fit_transform(df)
if motifs:
df = quantify_motifs(pd.DataFrame(df.iloc[:, 1:], df.iloc[:, 0].values.tolist()), feature_set).T
res = df.apply(jtkx, param_dic = param_dic, axis = 1)
JTK_BHQ = pd.DataFrame(sm.stats.multipletests(res[0], method = 'fdr_bh')[1])
Results = pd.concat([res.iloc[:, 0], JTK_BHQ, res.iloc[:, 1:]], axis = 1)
Results.columns = ['Molecule_Name', 'BH_Q_Value', 'Adjusted_P_value', 'Period_Length', 'Lag_Phase', 'Amplitude']
return Results
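
Not part of the commit, but a hedged sketch of how the wrapper is intended to be called; the file name and column labels below are hypothetical, with molecule IDs in column 0 and replicate columns ordered by ascending timepoint.

import pandas as pd
from glycowork.motif.analysis import get_jtk

# hypothetical design: 8 timepoints x 2 replicates sampled every 3 hours, testing an 8-timepoint (24 h) cycle
df = pd.read_csv("glycan_timecourse.csv")  # column 0 = glycan IDs, then t0_rep1, t0_rep2, t1_rep1, ...
res = get_jtk(df, timepoints = 8, replicates = 2, periods = [8], interval = 3)
print(res.sort_values('Adjusted_P_value').head())
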
220 changes: 218 additions & 2 deletions build/lib/glycowork/motif/processing.py
@@ -1,12 +1,13 @@
import pandas as pd
import numpy as np
import copy
import math
import re
from functools import wraps
from collections import defaultdict
from sklearn.ensemble import RandomForestRegressor
from sklearn.base import BaseEstimator
from scipy.special import gammaln
from scipy.stats import norm, rankdata, wilcoxon
from statsmodels.stats.multitest import multipletests
from glycowork.glycan_data.loader import unwrap, multireplace, find_nth, find_nth_reverse, linkages, Hex, HexNAc, dHex, Sia, HexA, Pen
from glycowork.glycan_data.loader import unwrap, multireplace, find_nth, find_nth_reverse, linkages, Hex, HexNAc, dHex, Sia, HexA, Pen, hlm, update_cf_for_m_n
rng = np.random.default_rng(42)


@@ -541,7 +542,8 @@ def replacement(match):
'axxxxh-1x_1-5_2*NCC/3=O': 'HexNAc?', 'axxxxh-1x_1-5': 'Hex?', 'a2112h-1b_1-4': 'Galfb',
'a2122h-1x_1-5_2*NCC/3=O_6*OSO/3=O/3=O': 'GlcNAc6Sb', 'a2112h-1x_1-5_2*NCC/3=O': 'GalNAc?',
'axxxxh-1a_1-5_2*NCC/3=O': 'HexNAca', 'Aad21122h-2a_2-6_4*OCC/3=O_5*NCC/3=O': 'Neu4Ac5Aca',
'a2112h-1b_1-5_4*OSO/3=O/3=O': 'Gal4Sb', 'a2122h-1b_1-5_2*NCC/3=O_3*OSO/3=O/3=O': 'GlcNAc3Sb'
'a2112h-1b_1-5_4*OSO/3=O/3=O': 'Gal4Sb', 'a2122h-1b_1-5_2*NCC/3=O_3*OSO/3=O/3=O': 'GlcNAc3Sb',
'a2112h-1b_1-5_2*NCC/3=O_4*OSO/3=O/3=O': 'GalNAc4Sb', 'a2122A-1x_1-5_?*OSO/3=O/3=O': 'GlcAOS?'
}
parts = wurcs.split('/')
topology = parts[-1].split('_')
@@ -1060,3 +1062,217 @@ def variance_based_filtering(df, min_feature_variance = 0.01):
variable_features = feature_variances[feature_variances > min_feature_variance].index
# Subsetting df to only include features with enough variance
return df.loc[variable_features]


def jtkdist(timepoints, param_dic, reps = 1, normal = False):
"""Precalculates all possible JT test statistic permutation probabilities for reference later, speeding up the
| analysis. Calculates the exact null distribution using the Harding algorithm.\n
| Arguments:
| :-
| timepoints (int): number of timepoints within the experiment.
| param_dic (dict): dictionary carrying around the parameter values
| reps (int): number of replicates within each timepoint.
| normal (bool): a flag for normal approximation if maximum possible negative log p-value is too large.\n
| Returns:
| :-
| Returns statistical values, added to 'param_dic'.
"""
timepoints = timepoints if isinstance(timepoints, int) else timepoints.sum()
tim = np.full(timepoints, reps) if reps != timepoints else reps # Support for unbalanced replication (unequal replicates in all groups)
maxnlp = gammaln(np.sum(tim)) - np.sum(np.log(np.arange(1, np.max(tim)+1)))
limit = math.log(float('inf'))
normal = normal or (maxnlp > limit - 1) # Switch to normal approximation if maxnlp is too large
lab = []
nn = sum(tim) # Number of data values (Independent of period and lag)
M = (nn ** 2 - np.sum(np.square(tim))) / 2 # Max possible jtk statistic
param_dic.update({"GRP_SIZE": tim, "NUM_GRPS": len(tim), "NUM_VALS": nn,
"MAX": M, "DIMS": [int(nn * (nn - 1) / 2), 1]})
if normal:
param_dic["VAR"] = (nn ** 2 * (2 * nn + 3) - np.sum(np.square(tim) * (2 * t + 3) for t in tim)) / 72 # Variance of JTK
param_dic["SDV"] = math.sqrt(param_dic["VAR"]) # Standard deviation of JTK
param_dic["EXV"] = M / 2 # Expected value of JTK
param_dic["EXACT"] = False
MM = int(M // 2) # Mode of this possible alternative to JTK distribution
cf = [1] * (MM + 1) # Initial lower half cumulative frequency (cf) distribution
size = sorted(tim) # Sizes of each group of known replicate values, in ascending order for fastest calculation
k = len(tim) # Number of groups of replicates
N = [size[k-1]]
if k > 2:
for i in range(k - 1, 1, -1): # Count permutations using the Harding algorithm
N.insert(0, (size[i] + N[0]))
for m, n in zip(size[:-1], N):
update_cf_for_m_n(m, n, MM, cf)
cf = np.array(cf)
# cf now contains the lower half cumulative frequency distribution
# append the symmetric upper half cumulative frequency distribution to cf
if M % 2: # jtkcf = upper-tail cumulative frequencies for all integer jtk
jtkcf = np.concatenate((cf, 2 * cf[MM] - cf[:MM][::-1], [2 * cf[MM]]))[::-1]
else:
jtkcf = np.concatenate((cf, cf[MM - 1] + cf[MM] - cf[:MM-1][::-1], [cf[MM - 1] + cf[MM]]))[::-1]
ajtkcf = [(jtkcf[i - 1] + jtkcf[i]) / 2 for i in range(1, len(jtkcf))] # interpolated cumulative frequency values for all half-integer jtk
cf = [ajtkcf[(j - 1) // 2] if j % 2 == 0 else jtkcf[j // 2] for j in range(1, 2 * int(M) + 2)]
param_dic["CP"] = [c / jtkcf[0] for c in cf] # all upper-tail p-values
return param_dic
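
As a sanity check (not part of the commit, assuming the scipy/statsmodels imports at the top of this module), a minimal single-replicate design reproduces the exact permutation null distribution of the statistic:

from glycowork.motif.processing import jtkdist

param_dic = {"GRP_SIZE": [], "NUM_GRPS": [], "MAX": [], "DIMS": [], "EXACT": True,
             "VAR": [], "EXV": [], "SDV": [], "CGOOSV": []}
param_dic = jtkdist(4, param_dic, reps = 1)  # 4 timepoints, 1 replicate each
print(param_dic["MAX"])      # 6.0 - the largest attainable JT statistic (4*3/2 orderable pairs)
print(len(param_dic["CP"]))  # 13 - upper-tail p-values for every half-integer statistic value from 0 to 6
print(param_dic["CP"][0])    # 1.0 - P(JTK >= 0)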


def jtkinit(periods, param_dic, interval = 1, replicates = 1):
"""Defines the parameters of the simulated sine waves for reference later.\n
| Each molecular species within the analysis is matched to the optimal wave defined here, and the parameters
| describing that wave are attributed to the molecular species.\n
| Arguments:
| :-
| periods (list): the possible periods of rhythmicity in the biological data (valued as 'number of timepoints').
| (note: periods can accept multiple values (i.e., you can define circadian rhythms as between 22, 24 and 26 hours))
| param_dic (dict): dictionary carrying around the parameter values
| interval (int): the number of units of time (arbitrary) between each experimental timepoint.
| replicates (int): number of replicates within each group.\n
| Returns:
| :-
| Returns values describing waveforms, added to 'param_dic'.
"""
param_dic["INTERVAL"] = interval
param_dic["PERIODS"] = list(periods)
param_dic["PERFACTOR"] = np.concatenate([np.repeat(i, ti) for i, ti in enumerate(periods, start = 1)])
tim = np.array(param_dic["GRP_SIZE"])
timepoints = int(param_dic["NUM_GRPS"])
timerange = np.arange(timepoints) # Zero-based time indices
param_dic["SIGNCOS"] = np.zeros((periods[0], ((math.floor(timepoints / (periods[0]))*int(periods[0]))* replicates)), dtype = int)
for i, period in enumerate(periods):
time2angle = np.array([(2*round(math.pi, 4))/period]) # convert time to angle using an ~pi value
theta = timerange*time2angle # zero-based angular values across time indices
cos_v = np.cos(theta) # unique cosine values at each time point
cos_r = np.repeat(rankdata(cos_v), np.max(tim)) # replicated ranks of unique cosine values
cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).astype(int)
lower_tri = []
for col in range(len(cgoos)):
for row in range(col + 1, len(cgoos)):
lower_tri.append(cgoos[row, col])
cgoos = np.array(lower_tri)
cgoosv = np.array(cgoos).reshape(param_dic["DIMS"])
param_dic["CGOOSV"] = []
param_dic["CGOOSV"].append(np.zeros((cgoos.shape[0], period)))
param_dic["CGOOSV"][i][:, 0] = cgoosv[:, 0]
cycles = math.floor(timepoints / period)
jrange = np.arange(cycles * period)
cos_s = np.sign(cos_v)[jrange]
cos_s = np.repeat(cos_s, (tim[jrange]))
if replicates == 1:
param_dic["SIGNCOS"][:, i] = cos_s
else:
param_dic["SIGNCOS"][i] = cos_s
for j in range(1, period): # One-based half-integer lag index j
delta_theta = j * time2angle / 2 # Angles of half-integer lags
cos_v = np.cos(theta + delta_theta) # Cycle left
cos_r = np.concatenate([np.repeat(val, num) for val, num in zip(rankdata(cos_v), tim)]) # Phase-shifted replicated ranks
cgoos = np.sign(np.subtract.outer(cos_r, cos_r)).T
mask = np.triu(np.ones(cgoos.shape), k = 1).astype(bool)
mask[np.diag_indices(mask.shape[0])] = False
cgoos = cgoos[mask]
cgoosv = cgoos.reshape(param_dic["DIMS"])
matrix_i = param_dic["CGOOSV"][i]
matrix_i[:, j] = cgoosv.flatten()
param_dic["CGOOSV[i]"] = matrix_i
cos_v = cos_v.flatten()
cos_s = np.sign(cos_v)[jrange]
cos_s = np.repeat(cos_s, (tim[jrange]))
if replicates == 1:
param_dic["SIGNCOS"][:, j] = cos_s
else:
param_dic["SIGNCOS"][j] = cos_s
return param_dic


def jtkstat(z, param_dic):
"""Determines the JTK statistic and p-values for all model phases, compared to expression data.\n
| Arguments:
| :-
| z (pd.DataFrame): expression data for a molecule ordered in groups, by timepoint.
| param_dic (dict): a dictionary containing parameters defining model waveforms.\n
| Returns:
| :-
| Returns an updated parameter dictionary where the appropriate model waveform has been assigned to the
| molecules in the analysis.
"""
param_dic["CJTK"] = []
M = param_dic["MAX"]
z = np.array(z)
foosv = np.sign(np.subtract.outer(z, z)).T # Due to differences in the triangle indexing of R / Python we need to transpose and select upper triangle rather than the lower triangle
mask = np.triu(np.ones(foosv.shape), k = 1).astype(bool) # Additionally, we need to remove the middle diagonal from the tri index
mask[np.diag_indices(mask.shape[0])] = False
foosv = foosv[mask].reshape(param_dic["DIMS"])
for i in range(param_dic["PERIODS"][0]):
cgoosv = param_dic["CGOOSV"][0][:, i]
S = np.nansum(np.diag(foosv * cgoosv))
jtk = (abs(S) + M) / 2 # Two-tailed JTK statistic for this lag and distribution
if S == 0:
param_dic["CJTK"].append([1, 0, 0])
elif param_dic.get("EXACT", False):
jtki = 1 + 2 * int(jtk) # index into the exact upper-tail distribution
p = 2 * param_dic["CP"][jtki-1]
param_dic["CJTK"].append([p, S, S / M])
else:
p = 2 * norm.cdf(-(jtk - 0.5), -param_dic["EXV"], param_dic["SDV"])
param_dic["CJTK"].append([p, S, S / M]) # include tau = s/M for this lag and distribution
return param_dic


def jtkx(z, param_dic, ampci = False):
"""Deployment of jtkstat for repeated use, and parameter extraction\n
| Arguments:
| :-
| z (pd.DataFrame): expression data ordered in groups, by timepoint.
| param_dic (dict): a dictionary containing parameters defining model waveforms.
| ampci (bool): flag for calculating amplitude confidence interval (True = compute); default=False.\n
| Returns:
| :-
| Returns an updated parameter dictionary containing the optimal waveform parameters for each molecular species.
"""
param_dic = jtkstat(z, param_dic) # Calculate p and S for all phases
pvals = [cjtk[0] for cjtk in param_dic["CJTK"]] # Exact two-tailed p values for period/phase combos
padj = multipletests(pvals, method = 'fdr_bh')[1]
JTK_ADJP = min(padj) # Global minimum adjusted p-value
def groupings(padj, param_dic):
d = defaultdict(list)
for i, value in enumerate(padj):
key = param_dic["PERFACTOR"][i]
d[key].append(value)
return dict(d)
dpadj = groupings(padj, param_dic)
padj = np.array(pd.DataFrame(dpadj.values()).T)
minpadj = [padj[i].min() for i in range(0, np.shape(padj)[1])] # Minimum adjusted p-values for each period
if len(param_dic["PERIODS"]) > 1:
pers_index = np.where(JTK_ADJP == minpadj)[0] # indices of all optimal periods
else:
pers_index = 0
pers = param_dic["PERIODS"][int(pers_index)] # all optimal periods
padj_values = padj[pers_index]
lagis = np.where(padj == JTK_ADJP)[0] # list of optimal lag indices for each optimal period
best_results = {'bestper': 0, 'bestlag': 0, 'besttau': 0, 'maxamp': 0, 'maxamp_ci': 2, 'maxamp_pval': 0}
sc = np.transpose(param_dic["SIGNCOS"])
w = (z[:len(sc)] - hlm(z[:len(sc)])) * math.sqrt(2)
for i in range(abs(pers)):
for lagi in lagis:
S = param_dic["CJTK"][lagi][1]
s = np.sign(S) if S != 0 else 1
lag = (pers + (1 - s) * pers / 4 - lagi / 2) % pers
signcos = sc[:, lagi]
tmp = s * w * sc[:, lagi]
amp = hlm(tmp) # Allows missing values
if ampci:
jtkwt = pd.DataFrame(wilcoxon(tmp[np.isfinite(tmp)], zero_method = 'wilcox', correction = False,
alternative = 'two-sided', mode = 'exact'))
amp = jtkwt['confidence_interval'].median() # Extract estimate (median) from the conf. interval
best_results['maxamp_ci'] = jtkwt['confidence_interval'].values
best_results['maxamp_pval'] = jtkwt['pvalue'].values
if amp > best_results['maxamp']:
best_results.update({'bestper': pers, 'bestlag': lag, 'besttau': [abs(param_dic["CJTK"][lagi][2])], 'maxamp': amp})
JTK_PERIOD = param_dic["INTERVAL"] * best_results['bestper']
JTK_LAG = param_dic["INTERVAL"] * best_results['bestlag']
JTK_AMP = float(max(0, best_results['maxamp']))
JTK_TAU = best_results['besttau']
JTK_AMP_CI = best_results['maxamp_ci']
JTK_AMP_PVAL = best_results['maxamp_pval']
return pd.Series([JTK_ADJP, JTK_PERIOD, JTK_LAG, JTK_AMP])
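
A minimal end-to-end sketch (not part of the commit) of how jtkdist, jtkinit and jtkx chain together on a single, arbitrarily chosen series; it assumes the scipy/statsmodels imports at the top of this module.

import numpy as np
from glycowork.motif.processing import jtkdist, jtkinit, jtkx

param_dic = {"GRP_SIZE": [], "NUM_GRPS": [], "MAX": [], "DIMS": [], "EXACT": True,
             "VAR": [], "EXV": [], "SDV": [], "CGOOSV": []}
param_dic = jtkdist(4, param_dic, reps = 1)                        # exact null distribution
param_dic = jtkinit([4], param_dic, interval = 6, replicates = 1)  # reference cosine waveforms
jtkx(np.array([10.0, 2.0, 1.0, 4.0]), param_dic)                   # pd.Series([adj. p-value, period, lag phase, amplitude])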