-
Notifications
You must be signed in to change notification settings - Fork 10
/
mbgan_inference_casectrl.py
107 lines (80 loc) · 3.44 KB
/
mbgan_inference_casectrl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import pickle
import numpy as np
import pandas as pd
from scipy.stats import describe
from keras.models import load_model
# Force CPU-only execution: enumerate GPUs in PCI bus order, then hide all of
# them from TensorFlow/Keras (inference here is cheap enough for CPU).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Seed for reproducible latent-noise sampling in simulate().
SEED = 256
# Abundance threshold: values <= TOL are treated as absent taxa.
TOL = 1e-4
# Paths to the pre-trained GAN generator weights (HDF5, loaded via Keras).
GENERATOR_CASE_PATH = os.path.join('models', 'stool_2_case_generator.h5')
GENERATOR_CTRL_PATH = os.path.join('models', 'stool_2_ctrl_generator.h5')
def shannon_entropy(x, tol=0.):
    """Shannon entropy (natural log) of each sample along the last axis.

    Entries at or below ``tol`` contribute zero, matching the convention
    ``0 * log(0) == 0``.

    Args:
        x: array of non-negative relative abundances; entropy is computed
           over the last axis.
        tol: abundances <= tol are excluded from the sum.

    Returns:
        Array of entropies with the last axis reduced.
    """
    x = np.asarray(x)
    # Feed log() a safe value (1 -> log == 0) where x is masked out, so no
    # RuntimeWarning / NaN is produced for zero abundances; np.where then
    # selects 0 for those positions exactly as before.
    safe = np.where(x > tol, x, 1.0)
    return -np.sum(np.where(x > tol, x * np.log(safe), 0), axis=-1)
def get_sparsity(x, tol=0.):
    """Fraction of features at or below ``tol`` per sample (last axis).

    Args:
        x: abundance array; sparsity is computed over the last axis.
        tol: threshold below (or at) which a feature counts as absent.

    Returns:
        Array of sparsity fractions in [0, 1] with the last axis reduced.
    """
    x = np.asarray(x)
    # shape[-1] (not shape[1]) so 1-D inputs work too and the divisor always
    # matches the axis=-1 reduction; identical result for 2-D input.
    return np.sum(x <= tol, axis=-1) / x.shape[-1]
def expand_phylo(taxa_list):
    """Expand '|'-delimited leaf taxa to every ancestor level.

    Each taxon string encodes its lineage as pipe-separated ranks; every
    prefix is an ancestor. New ancestor taxa get indices appended after the
    leaf indices (0..len(taxa_list)-1).

    Usage:
        adj_matrix, taxa_indices = expand_phylo(colnames)

    Returns:
        A pair ``(edges, index_to_name)`` where ``edges`` is a list of
        (leaf_index, ancestor_index) pairs (a leaf links to itself as its
        deepest "ancestor") and ``index_to_name`` maps each index back to
        its taxon string.
    """
    index_of = {}
    edges = []
    next_index = len(taxa_list)
    for leaf_index, taxon in enumerate(taxa_list):
        # Register the full leaf name first so its self-edge reuses leaf_index.
        index_of[taxon] = leaf_index
        ranks = taxon.split('|')
        for depth in range(1, len(ranks) + 1):
            ancestor = '|'.join(ranks[:depth])
            if ancestor not in index_of:
                index_of[ancestor] = next_index
                next_index += 1
            edges.append((index_of[taxon], index_of[ancestor]))
    return edges, {idx: name for name, idx in index_of.items()}
def adjmatrix_to_dense(x, shape, val=1):
    """Build a dense matrix of the given shape from index pairs.

    Args:
        x: sequence of index tuples, e.g. (row, col) pairs from expand_phylo.
        shape: shape of the dense output matrix.
        val: value written at each listed position (default 1).

    Returns:
        np.ndarray of zeros with ``val`` at every position listed in ``x``.
    """
    dense = np.zeros(shape)
    # Transpose the pair list into per-axis index arrays for fancy indexing.
    dense[tuple(np.asarray(x).T)] = val
    return dense
def simulate(model, n_samples=1000, transform=None, seed=None):
    """Sample latent noise and run it through a trained generator.

    Args:
        model: Keras generator; latent size is read from its first input.
        n_samples: number of synthetic samples to draw.
        transform: optional callable applied to the generator output.
        seed: seed for the legacy NumPy global RNG (reproducible draws).

    Returns:
        Generated samples, optionally passed through ``transform``.
    """
    np.random.seed(seed)
    latent_dim = model.inputs[0].shape[-1]
    noise = np.random.normal(loc=0, scale=1, size=(n_samples, latent_dim))
    samples = model.predict(noise)
    return samples if transform is None else transform(samples)
if __name__ == "__main__":
    ## Load raw dataset: column 0 is the group label ('case'/'ctrl'); the
    ## remaining columns are taxa abundances in percent, rescaled to [0, 1].
    # NOTE(review): pickle.load is only safe on trusted local data files.
    with open("./data/raw_data.pkl", 'rb') as f:
        raw_data = pickle.load(f)
    dataset = raw_data.iloc[:, 1:].values / 100.
    labels = raw_data["group"].values
    taxa_list = raw_data.columns[1:]
    data_o_case = dataset[labels == 'case']
    data_o_ctrl = dataset[labels == 'ctrl']
    ## Generate synthetic data from the pre-trained GAN generators, one
    ## generator per group, using the same seed for reproducibility.
    generator_case = load_model(GENERATOR_CASE_PATH)
    generator_ctrl = load_model(GENERATOR_CTRL_PATH)
    data_g_case = simulate(generator_case, n_samples=1000, seed=SEED)
    data_g_ctrl = simulate(generator_ctrl, n_samples=1000, seed=SEED)
    ## Compare summary statistics of original vs generated samples.
    row_names = ['Original ctrl', 'GAN ctrl', 'Original case', 'GAN case']
    print("Sparsity")
    print(pd.DataFrame(
        [describe(get_sparsity(data_o_ctrl, TOL)),
         describe(get_sparsity(data_g_ctrl, TOL)),
         describe(get_sparsity(data_o_case, TOL)),
         describe(get_sparsity(data_g_case, TOL))],
        index=row_names))
    print("Shannon Entropy")
    print(pd.DataFrame(
        [describe(shannon_entropy(data_o_ctrl)),
         describe(shannon_entropy(data_g_ctrl)),
         describe(shannon_entropy(data_o_case)),
         describe(shannon_entropy(data_g_case))],
        index=row_names))
    ## Save simulated species-level data.
    # Create the output directory up front so to_csv does not fail on a
    # fresh checkout where ./outputs does not exist yet.
    os.makedirs("./outputs", exist_ok=True)
    pd.DataFrame(data_g_case, columns=taxa_list).to_csv("./outputs/stools_2_case_species.csv")
    pd.DataFrame(data_g_ctrl, columns=taxa_list).to_csv("./outputs/stools_2_ctrl_species.csv")
    ## Save simulated data aggregated to every phylogenetic level: the dense
    ## leaf-to-ancestor matrix sums species abundances up each lineage.
    adj_matrix, taxa_indices = expand_phylo(taxa_list)
    tf_matrix = adjmatrix_to_dense(adj_matrix, shape=(len(taxa_list), len(taxa_indices)))
    pd.DataFrame(np.dot(data_g_case, tf_matrix), columns=taxa_indices).to_csv("./outputs/stools_2_case_phylo.csv")
    pd.DataFrame(np.dot(data_g_ctrl, tf_matrix), columns=taxa_indices).to_csv("./outputs/stools_2_ctrl_phylo.csv")