-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
82 lines (66 loc) · 2.58 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from __future__ import division
import argparse
import numpy as np
import torch
from triplet import train_triplet
from acai import train_acai
from umap import train_umap
from tsne import train_tsne
from vae import train_vae
from support_func import sanitize
from data import load_dataset
if __name__ == '__main__':
    # Command-line entry point: parse options, load a dataset, split its base
    # vectors into a training and a validation set, and hand both (plus the
    # query vectors) to the selected dimensionality-reduction trainer.
    import os

    # Maps each accepted --method value to its training function. Every
    # trainer is called as trainer(xt, xv, xq, args, results_file_name).
    trainers = {
        "triplet": train_triplet,
        "acai": train_acai,
        "umap": train_umap,
        "tsne": train_tsne,
        "vae": train_vae,
    }

    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        # Add an option to whichever argument group is currently bound to
        # `group` (it is rebound before each section of options below).
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('dataset options')
    aa("--database", default="sift")
    # Validated by argparse so a typo fails immediately, before the dataset
    # is loaded and before anything is appended to the results file.
    aa("--method", type=str, default="triplet", choices=sorted(trainers))
    aa("--val_fraction", type=float, default=0.1,
       help="fraction of the base vectors held out for validation")

    group = parser.add_argument_group('Model hyperparameters')
    aa("--dout", type=int, default=16,
       help="output dimension")

    group = parser.add_argument_group('Computation params')
    aa("--seed", type=int, default=1234)
    aa("--device", choices=["cuda", "cpu", "auto"], default="auto")
    aa("--val_freq", type=int, default=10,
       help="frequency of validation calls")
    aa("--print_results", type=int, default=0)
    aa("--batch_size", type=int, default=64)
    aa("--epochs", type=int, default=40)
    aa("--lr_schedule", type=str, default="0.1,0.1,0.05,0.01")
    aa("--momentum", type=float, default=0.9)
    aa("--results_dir", type=str,
       default="/home/shekhale/results/dim_red_zoo/",
       help="root directory; results go to <results_dir>/<database>/")

    args = parser.parse_args()

    if not 0.0 < args.val_fraction < 1.0:
        parser.error("--val_fraction must be strictly between 0 and 1")

    if args.device == "auto":
        args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # Seed both RNGs: numpy drives the train/validation split below, torch
    # drives whatever randomness the trainers use.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print(args)

    # With the default --results_dir this is the same path the script has
    # always used: <root>/<database>/train_results_<method>.txt
    results_file_name = os.path.join(
        args.results_dir, args.database,
        "train_results_" + args.method + ".txt")
    if args.print_results > 0:
        # Opening in append mode fails if the directory is missing.
        results_parent = os.path.dirname(results_file_name)
        if results_parent and not os.path.isdir(results_parent):
            os.makedirs(results_parent)
        with open(results_file_name, "a") as rfile:
            rfile.write("\n\n")
            rfile.write("START TRAINING \n")

    print("load dataset %s" % args.database)
    (_, xb, xq, _) = load_dataset(args.database, args.device, calc_gt=False, mnt=True)

    # Hold out a random subset of the base vectors for validation; the
    # remainder is the training set.
    base_size = xb.shape[0]
    threshold = int(base_size * args.val_fraction)
    perm = np.random.permutation(base_size)
    xv = xb[perm[:threshold]]
    xt = xb[perm[threshold:]]
    print(xb.shape, xt.shape, xv.shape, xq.shape)

    xt = sanitize(xt)
    xv = sanitize(xv)
    xq = sanitize(xq)
    # The full base set is not passed to any trainer, so drop the reference
    # rather than spending time (and memory) sanitizing it.
    del xb

    trainers[args.method](xt, xv, xq, args, results_file_name)