-
Notifications
You must be signed in to change notification settings - Fork 2
/
emb_Gen.py
55 lines (44 loc) · 1.62 KB
/
emb_Gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
Modified based on codes implemented by Lucas Hu et al
doi:10.5281/zenodo.1408472
"""
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import scipy.sparse as sp
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import pickle
from node2vec import node2vec
from gensim.models import Word2Vec
from node2vec.preprocessing import mask_test_edges
network_dir = './GraphPickle/HPO-Orphanet.pkl'
with open(network_dir, 'rb') as f:
adj, features = pickle.load(f)
g = nx.Graph(adj)
np.random.seed(0)
adj_sparse = nx.to_scipy_sparse_matrix(g)
# Perform train-test split
adj_train, train_edges, train_edges_false, val_edges, val_edges_false, \
test_edges, test_edges_false = mask_test_edges(adj_sparse, test_frac=.3, val_frac=.1)
g_train = nx.from_scipy_sparse_matrix(adj_train)
P = 1 # Return hyperparameter
Q = 0.05 # In-out hyperparameter
WINDOW_SIZE = 10 # Context size for optimization
NUM_WALKS = 10 # Number of walks per source
WALK_LENGTH = 5 # Length of walk per source
DIMENSIONS = 128 # Embedding dimension
DIRECTED = False # Graph directed/undirected
WORKERS = 8 # Num. parallel workers
ITER = 1 # SGD epochs
# Preprocessing, generate walks
g_n2v = node2vec.Graph(g_train, DIRECTED, P, Q)
g_n2v.preprocess_transition_probs()
walks = g_n2v.simulate_walks(NUM_WALKS, WALK_LENGTH)
walks = [list(map(str, walk)) for walk in walks]
print(walks)
# Train skip-gram model
model = Word2Vec(walks, size=DIMENSIONS, window=WINDOW_SIZE, min_count=0, sg=1, workers=WORKERS, iter=ITER)
emb_mappings = model.wv
model.wv.save_word2vec_format('HPOEmb-Orphanet.emd')