-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathkarateclub.py
59 lines (45 loc) · 1.75 KB
/
karateclub.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import networkx as nx
import pandas as pd
import imageio
import matplotlib.pyplot as plt
import tqdm
import pathlib
from fastrec import GraphRecommender
def animate(labelsnp, all_embeddings, mask, frames_dir='./ims/', gif_path='./animation.gif'):
    """Render one scatter plot per item of ``all_embeddings`` and stitch them into a GIF.

    Each frame plots column 0 against column 1 of the masked embedding.
    Points are red when the node's label is ``'Administrator'`` and blue
    otherwise. Rows 0 and 33 of the masked data are annotated
    "Administrator" and "Instructor" (hard-coded for the karate-club demo).

    Args:
        labelsnp: Per-node labels, indexed with ``mask``. Presumably a NumPy
            array so that boolean/fancy indexing works (TODO confirm).
        all_embeddings: Iterable of embedding arrays, one per frame. Each is
            indexed with ``mask`` and then as ``[:, 0]`` / ``[:, 1]``, so it
            must be at least 2-D with two or more columns.
        mask: Index or mask selecting which nodes to plot.
        frames_dir: Directory the PNG frames are written to as
            ``<index>.png``. Created if missing. Defaults to ``'./ims/'``.
        gif_path: Output path of the animated GIF. Defaults to
            ``'./animation.gif'``.
    """
    labelsnp = labelsnp[mask]
    # Labels are the same for every frame, so the colours are computed once.
    colormap = ['r' if label == 'Administrator' else 'b' for label in labelsnp]

    frames = pathlib.Path(frames_dir)
    # savefig does not create missing parent directories.
    frames.mkdir(parents=True, exist_ok=True)

    # Keep exactly the frames written by this call, in order.
    # - Leftover PNGs from an earlier (e.g. longer) run are not spliced
    #   into the animation.
    # - No platform-dependent filename parsing is needed for sorting.
    frame_paths = []
    for i, embedding in enumerate(tqdm.tqdm(all_embeddings)):
        data = embedding[mask]
        fig = plt.figure(dpi=150)
        ax = fig.subplots()
        ax.set_title('Epoch {}'.format(i))
        ax.scatter(data[:, 0], data[:, 1], c=colormap)
        # Row positions 0 and 33 are hard-coded for the karate-club graph.
        ax.annotate('Administrator', (data[0, 0], data[0, 1]))
        ax.annotate('Instructor', (data[33, 0], data[33, 1]))
        frame_path = frames / '{n}.png'.format(n=i)
        fig.savefig(frame_path)
        # Close this specific figure so figures don't accumulate across frames.
        plt.close(fig)
        frame_paths.append(frame_path)

    with imageio.get_writer(gif_path, mode='I') as writer:
        for frame_path in frame_paths:
            # NOTE(review): newer imageio releases deprecate top-level imread
            # in favour of imageio.v2.imread. Kept as-is because the installed
            # version is unknown.
            writer.append_data(imageio.imread(str(frame_path)))
if __name__ == '__main__':
    # Data: Zachary's karate club graph plus a CSV of per-node attributes.
    graph = nx.karate_club_graph()
    node_ids = list(graph.nodes)
    # Transpose the (u, v) edge pairs into two parallel tuples of endpoints.
    sources, targets = zip(*graph.edges)
    attributes = pd.read_csv('./karate_attributes.csv')

    # The first argument is presumably the embedding dimension; animate()
    # plots exactly two columns (confirm against fastrec).
    sage = GraphRecommender(2, distance='l2')
    sage.add_nodes(node_ids)
    # Each edge is added once per direction (u->v, then v->u). This is
    # presumably so the recommender sees the undirected graph symmetrically.
    for src, dst in ((sources, targets), (targets, sources)):
        sage.add_edges(src, dst)
    sage.update_labels(attributes.community)

    epochs, batch_size = 150, 15
    training_options = {
        'unsupervised': True,
        'learning_rate': 1e-2,
        'test_every_n_epochs': 10,
        'return_intermediate_embeddings': True,
    }
    # train() is unpacked into exactly three values. Only the last is kept:
    # the sequence animate() renders one frame per item.
    _, _, all_embeddings = sage.train(epochs, batch_size, **training_options)

    animate(sage.labels, all_embeddings, sage.entity_mask)
    print(sage.query_neighbors([0, 33], k=5))
    # NOTE(review): presumably starts a serving API and may block; confirm.
    sage.start_api()