Skip to content

Commit

Permalink
add new url_path and change the download func
Browse files Browse the repository at this point in the history
  • Loading branch information
heatingma committed Aug 6, 2023
1 parent 18c98a9 commit f408a51
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 164 deletions.
77 changes: 43 additions & 34 deletions pygmtools/jittor_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,15 +923,18 @@ def execute(self, feat1, feat2, A1, A2, n1, n2, cross_iter_num, sk_max_iter, sk_


pca_gm_pretrain_path = {
'voc': ('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1k4eBJ869uX7sN9TVTe67-8ZKRffpeBu8',
'https://www.dropbox.com/scl/fi/gc7ekhxdeump5znzv8nnz/pca_gm_voc_jittor.pt?rlkey=h9fe2d3cfn4r1fumvaqjjia16&dl=1',
'112bb91bd0ccc573c3a936c49416d79e'),
'willow': ('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=15R3mdOR99g1LuSyv2IikRmlvy06ub7GQ',
'https://www.dropbox.com/scl/fi/1irrb63jdz30m2ebj3lhy/pca_gm_willow_jittor.pt?rlkey=njkwysy6eh89wa4jgrl54v767&dl=1',
'72f4decf48eb5e00933699518563035a'),
'voc-all': ('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=17QvlZRAFcPBslaMCax9BVmQpoFMUWv5I',
'https://www.dropbox.com/scl/fi/yl8vymg3rc52n4wxr50gh/pca_gm_voc-all_jittor.pt?rlkey=niypnhmsve6md495j59psqs37&dl=1',
'65cdf9ab437fa37c18eac147cb490c8f')
'voc': (['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_voc_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1k4eBJ869uX7sN9TVTe67-8ZKRffpeBu8',
'https://www.dropbox.com/scl/fi/gc7ekhxdeump5znzv8nnz/pca_gm_voc_jittor.pt?rlkey=h9fe2d3cfn4r1fumvaqjjia16&dl=1'],
'112bb91bd0ccc573c3a936c49416d79e'),
'willow': (['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_willow_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=15R3mdOR99g1LuSyv2IikRmlvy06ub7GQ',
'https://www.dropbox.com/scl/fi/1irrb63jdz30m2ebj3lhy/pca_gm_willow_jittor.pt?rlkey=njkwysy6eh89wa4jgrl54v767&dl=1'],
'72f4decf48eb5e00933699518563035a'),
'voc-all': (['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_voc-all_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=17QvlZRAFcPBslaMCax9BVmQpoFMUWv5I',
'https://www.dropbox.com/scl/fi/yl8vymg3rc52n4wxr50gh/pca_gm_voc-all_jittor.pt?rlkey=niypnhmsve6md495j59psqs37&dl=1'],
'65cdf9ab437fa37c18eac147cb490c8f')
}


Expand All @@ -949,8 +952,8 @@ def pca_gm(feat1, feat2, A1, A2, n1, n2,
network = PCA_GM_Net(in_channel, hidden_channel, out_channel, num_layers)
if pretrain:
if pretrain in pca_gm_pretrain_path:
url, url_alter, md5 = pca_gm_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'pca_gm_{pretrain}_jittor.pt', url, md5, url_alter)
url, md5 = pca_gm_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'pca_gm_{pretrain}_jittor.pt', url, md5)
_load_model(network, filename)
else:
raise ValueError(f'Unknown pretrain tag. Available tags: {pca_gm_pretrain_path.keys()}')
Expand All @@ -968,13 +971,15 @@ def pca_gm(feat1, feat2, A1, A2, n1, n2,


ipca_gm_pretrain_path = {
'voc': ('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1B5W83efRL50C1D348xPJHaHoEXpAfKTL',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/jittor_backend/ipca_gm_voc_jittor.pt',
'voc': (['https://huggingface.co/heatingma/pygmtools/resolve/main/ipca_gm_voc_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1B5W83efRL50C1D348xPJHaHoEXpAfKTL',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/jittor_backend/ipca_gm_voc_jittor.pt'],
'3a6dc7948c75d2e31781847941b5f2f6'),

'willow': ('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1iHSAY0d7Ufw9slYQjD_dEMkUB8SQM0kO',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/jittor_backend/ipca_gm_willow_jittor.pt',
'5a1a5b783b9e7ba51579b724a26dccb4'),
'willow': (['https://huggingface.co/heatingma/pygmtools/resolve/main/ipca_gm_willow_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1iHSAY0d7Ufw9slYQjD_dEMkUB8SQM0kO',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/jittor_backend/ipca_gm_willow_jittor.pt'],
'5a1a5b783b9e7ba51579b724a26dccb4'),
}


Expand All @@ -992,8 +997,8 @@ def ipca_gm(feat1, feat2, A1, A2, n1, n2,
network = PCA_GM_Net(in_channel, hidden_channel, out_channel, num_layers, cross_iter)
if pretrain:
if pretrain in ipca_gm_pretrain_path:
url, url_alter, md5 = ipca_gm_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'ipca_gm_{pretrain}_jittor.pt', url, md5, url_alter)
url, md5 = ipca_gm_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'ipca_gm_{pretrain}_jittor.pt', url, md5)
_load_model(network, filename)
else:
raise ValueError(f'Unknown pretrain tag. Available tags: {ipca_gm_pretrain_path.keys()}')
Expand Down Expand Up @@ -1060,12 +1065,14 @@ def execute(self, feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2


cie_pretrain_path = {
'voc': ('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1jjzbtXne_ppdg7M2jWEpye8piURDVidY',
'https://www.dropbox.com/scl/fi/rual5ozkfrbe3205lzui3/cie_voc_jittor.pt?rlkey=zio8gca7qg8mc5a3murl6mew9&dl=1',
'dc398a5885c5d5894ed6667103d2ff18'),
'willow': ('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=11ftNCYBGnjGpFM3__oTCpBhOBabSU1Rv',
'https://www.dropbox.com/scl/fi/iyigyphvuil8ch7cura0n/cie_willow_jittor.pt?rlkey=44wgqd4njgwmj8qdbe9fjxnez&dl=1',
'bef2c341f605669ed4211e8ff7b1fe0b'),
'voc': (['https://huggingface.co/heatingma/pygmtools/resolve/main/cie_voc_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1jjzbtXne_ppdg7M2jWEpye8piURDVidY',
'https://www.dropbox.com/scl/fi/rual5ozkfrbe3205lzui3/cie_voc_jittor.pt?rlkey=zio8gca7qg8mc5a3murl6mew9&dl=1'],
'dc398a5885c5d5894ed6667103d2ff18'),
'willow': (['https://huggingface.co/heatingma/pygmtools/resolve/main/cie_willow_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=11ftNCYBGnjGpFM3__oTCpBhOBabSU1Rv',
'https://www.dropbox.com/scl/fi/iyigyphvuil8ch7cura0n/cie_willow_jittor.pt?rlkey=44wgqd4njgwmj8qdbe9fjxnez&dl=1'],
'bef2c341f605669ed4211e8ff7b1fe0b'),
}


Expand All @@ -1083,8 +1090,8 @@ def cie(feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2,
network = CIE_Net(in_node_channel, in_edge_channel, hidden_channel, out_channel, num_layers)
if pretrain:
if pretrain in cie_pretrain_path:
url, url_alter, md5 = cie_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'cie_{pretrain}_jittor.pt', url, md5, url_alter)
url, md5 = cie_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'cie_{pretrain}_jittor.pt', url, md5)
_load_model(network, filename)
else:
raise ValueError(f'Unknown pretrain tag. Available tags: {cie_pretrain_path.keys()}')
Expand Down Expand Up @@ -1141,12 +1148,14 @@ def execute(self, K, n1, n2, n1max, n2max, v0, sk_max_iter, sk_tau):


ngm_pretrain_path = {
'voc': ('https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/jittor_backend/ngm_voc_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1_KZQPR6msYsMXupfrAgGgXT-zUXaGtmL',
'1c01a48ee2095b70da270da9d862a8c0'),
'willow': ('https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/jittor_backend/ngm_willow_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1sLI7iC9kUyWm3xeByHvAMx_Hux8VAuP7',
'c23821751c895f79bbd038fa426ce259'),
'voc': (['https://huggingface.co/heatingma/pygmtools/resolve/main/ngm_voc_jittor.pt',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/jittor_backend/ngm_voc_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1_KZQPR6msYsMXupfrAgGgXT-zUXaGtmL'],
'1c01a48ee2095b70da270da9d862a8c0'),
'willow': (['https://huggingface.co/heatingma/pygmtools/resolve/main/ngm_willow_jittor.pt',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/jittor_backend/ngm_willow_jittor.pt',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1sLI7iC9kUyWm3xeByHvAMx_Hux8VAuP7'],
'c23821751c895f79bbd038fa426ce259'),
}


Expand All @@ -1162,9 +1171,9 @@ def ngm(K, n1, n2, n1max, n2max, x0, gnn_channels, sk_emb, sk_max_iter, sk_tau,
network = NGM_Net(gnn_channels, sk_emb)
if pretrain:
if pretrain in ngm_pretrain_path:
url, url_alter, md5 = ngm_pretrain_path[pretrain]
url, md5 = ngm_pretrain_path[pretrain]
try:
filename = pygmtools.utils.download(f'ngm_{pretrain}_jittor.pt', url, md5, url_alter)
filename = pygmtools.utils.download(f'ngm_{pretrain}_jittor.pt', url, md5)
except:
filename = os.path.dirname(__file__) + f'/temp/ngm_{pretrain}_jittor.pt'
_load_model(network, filename)
Expand Down
79 changes: 44 additions & 35 deletions pygmtools/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,15 +887,18 @@ def forward(self, feat1, feat2, A1, A2, n1, n2, cross_iter_num, sk_max_iter, sk_
return s

pca_gm_pretrain_path = {
'voc':('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1En_9f5Zi5rSsS-JTIce7B1BV6ijGEAPd',
'https://www.dropbox.com/s/x79ib1em4cgddqp/pca_gm_voc_numpy.npy?dl=1',
'd85f97498157d723793b8fc1501841ce'),
'willow':('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1LAnK6ASYu0CO1fEe6WpvMbt5vskuvwLo',
'https://www.dropbox.com/s/2vo4wpd9467bl5r/pca_gm_willow_numpy.npy?dl=1',
'c32f7c8a7a6978619b8fdbb6ad5b505f'),
'voc-all':('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1c_aw4wxEBuY7JFC4Rt8rlcise777n189',
'https://www.dropbox.com/s/6yunsy3gqxfvdyu/pca_gm_voc-all_numpy.npy?dl=1',
'0e2725b3ac51f87f0303bbcfaae5df80')
'voc':(['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_voc_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1En_9f5Zi5rSsS-JTIce7B1BV6ijGEAPd',
'https://www.dropbox.com/s/x79ib1em4cgddqp/pca_gm_voc_numpy.npy?dl=1'],
'd85f97498157d723793b8fc1501841ce'),
'willow':(['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_willow_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1LAnK6ASYu0CO1fEe6WpvMbt5vskuvwLo',
'https://www.dropbox.com/s/2vo4wpd9467bl5r/pca_gm_willow_numpy.npy?dl=1'],
'c32f7c8a7a6978619b8fdbb6ad5b505f'),
'voc-all':(['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_voc-all_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1c_aw4wxEBuY7JFC4Rt8rlcise777n189',
'https://www.dropbox.com/s/6yunsy3gqxfvdyu/pca_gm_voc-all_numpy.npy?dl=1'],
'0e2725b3ac51f87f0303bbcfaae5df80')
}

def pca_gm(feat1, feat2, A1, A2, n1, n2,
Expand All @@ -912,8 +915,8 @@ def pca_gm(feat1, feat2, A1, A2, n1, n2,
network = PCA_GM_Net(in_channel, hidden_channel, out_channel, num_layers)
if pretrain:
if pretrain in pca_gm_pretrain_path.keys():
url, url_alter, md5 = pca_gm_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'pca_gm_{pretrain}_numpy.npy', url, md5, url_alter)
url, md5 = pca_gm_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'pca_gm_{pretrain}_numpy.npy', url, md5)
pca_gm_numpy_dict = np.load(filename,allow_pickle=True)
for i in range(network.gnn_layer):
gnn_layer = network.dict['gnn_layer_{}'.format(i)]
Expand Down Expand Up @@ -943,12 +946,14 @@ def pca_gm(feat1, feat2, A1, A2, n1, n2,
return result, network

ipca_gm_pretrain_path = {
'voc':('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=13g9iBjXZ804bKo6p8wMQe8yNUZBwVGJj',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/numpy_backend/ipca_gm_voc_numpy.npy',
'4479a25558780a4b4c9891b4386659cd'),
'willow':('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1vq0FqjPhiSR80cu9jk0qMljkC4gSFvQA',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/numpy_backend/ipca_gm_willow_numpy.npy',
'ada1df350d45cc877f08e12919993345')
'voc':(['https://huggingface.co/heatingma/pygmtools/resolve/main/ipca_gm_voc_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=13g9iBjXZ804bKo6p8wMQe8yNUZBwVGJj',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/numpy_backend/ipca_gm_voc_numpy.npy'],
'4479a25558780a4b4c9891b4386659cd'),
'willow':(['https://huggingface.co/heatingma/pygmtools/resolve/main/ipca_gm_willow_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1vq0FqjPhiSR80cu9jk0qMljkC4gSFvQA',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/numpy_backend/ipca_gm_willow_numpy.npy'],
'ada1df350d45cc877f08e12919993345')
}

def ipca_gm(feat1, feat2, A1, A2, n1, n2,
Expand All @@ -965,8 +970,8 @@ def ipca_gm(feat1, feat2, A1, A2, n1, n2,
network = PCA_GM_Net(in_channel, hidden_channel, out_channel, num_layers, cross_iter)
if pretrain:
if pretrain in ipca_gm_pretrain_path.keys():
url, url_alter, md5 = ipca_gm_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'ipca_gm_{pretrain}_numpy.npy', url, md5, url_alter)
url, md5 = ipca_gm_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'ipca_gm_{pretrain}_numpy.npy', url, md5)
ipca_gm_numpy_dict = np.load(filename,allow_pickle=True)
for i in range(network.gnn_layer-1):
gnn_layer = network.dict['gnn_layer_{}'.format(i)]
Expand Down Expand Up @@ -1052,12 +1057,14 @@ def forward(self, feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2
return s

cie_pretrain_path = {
'voc':('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1rP9sJY1fh493LLMWw-7RaeFAMHlbSs2D',
'https://www.dropbox.com/s/vxh2e1y5s1jidmk/cie_voc_numpy.npy?dl=1',
'9cbd55fa77d124b95052378643715bae'),
'willow':('https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1cMiXrSQjXZ9lDxeB6194z1-luyslVTR8',
'https://www.dropbox.com/s/c3i1nf3ruedm8vk/cie_willow_numpy.npy?dl=1',
'bd36e1bf314503c1f1482794e1648b18')
'voc':(['https://huggingface.co/heatingma/pygmtools/resolve/main/cie_voc_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1rP9sJY1fh493LLMWw-7RaeFAMHlbSs2D',
'https://www.dropbox.com/s/vxh2e1y5s1jidmk/cie_voc_numpy.npy?dl=1'],
'9cbd55fa77d124b95052378643715bae'),
'willow':(['https://huggingface.co/heatingma/pygmtools/resolve/main/cie_willow_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1cMiXrSQjXZ9lDxeB6194z1-luyslVTR8',
'https://www.dropbox.com/s/c3i1nf3ruedm8vk/cie_willow_numpy.npy?dl=1'],
'bd36e1bf314503c1f1482794e1648b18')
}

def cie(feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2,
Expand All @@ -1074,8 +1081,8 @@ def cie(feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2,
network = CIE_Net(in_node_channel, in_edge_channel, hidden_channel, out_channel, num_layers)
if pretrain:
if pretrain in cie_pretrain_path.keys():
url, url_alter, md5 = cie_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'cie_{pretrain}_numpy.npy', url, md5, url_alter)
url, md5 = cie_pretrain_path[pretrain]
filename = pygmtools.utils.download(f'cie_{pretrain}_numpy.npy', url, md5)
cie_numpy_dict = np.load(filename,allow_pickle=True)
for i in range(network.gnn_layer):
gnn_layer = network.dict['gnn_layer_{}'.format(i)]
Expand Down Expand Up @@ -1145,12 +1152,14 @@ def forward(self, K, n1, n2, n1max, n2max, v0, sk_max_iter, sk_tau):
return _sinkhorn_func(s, n1, n2, dummy_row=True)

ngm_pretrain_path = {
'voc':('https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/numpy_backend/ngm_voc_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1LY93fLCjH5vDcWsjZxGPmXmrYMF8HZIR',
'19cd48afab71b3277d2062624934702c'),
'willow':('https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/numpy_backend/ngm_willow_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1iD8FHqahRsVV_H6o3ByB6nwBHU8sEgnt',
'31968e30c399845f34d80733d0118b8b')
'voc':(['https://huggingface.co/heatingma/pygmtools/resolve/main/ngm_voc_numpy.npy',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/numpy_backend/ngm_voc_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1LY93fLCjH5vDcWsjZxGPmXmrYMF8HZIR'],
'19cd48afab71b3277d2062624934702c'),
'willow':(['https://huggingface.co/heatingma/pygmtools/resolve/main/ngm_willow_numpy.npy',
'https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/numpy_backend/ngm_willow_numpy.npy',
'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1iD8FHqahRsVV_H6o3ByB6nwBHU8sEgnt'],
'31968e30c399845f34d80733d0118b8b')
}

def ngm(K, n1, n2, n1max, n2max, x0, gnn_channels, sk_emb, sk_max_iter, sk_tau, network, return_network, pretrain):
Expand All @@ -1165,9 +1174,9 @@ def ngm(K, n1, n2, n1max, n2max, x0, gnn_channels, sk_emb, sk_max_iter, sk_tau,
network = NGM_Net(gnn_channels, sk_emb)
if pretrain:
if pretrain in ngm_pretrain_path.keys():
url, url_alter, md5 = ngm_pretrain_path[pretrain]
url, md5 = ngm_pretrain_path[pretrain]
try:
filename = pygmtools.utils.download(f'ngm_{pretrain}_numpy.npy', url, md5, url_alter)
filename = pygmtools.utils.download(f'ngm_{pretrain}_numpy.npy', url, md5)
except:
filename = os.path.dirname(__file__) + f'/temp/ngm_{pretrain}_numpy.npy'
ngm_numpy_dict = np.load(filename, allow_pickle=True)
Expand Down
Loading

0 comments on commit f408a51

Please sign in to comment.