-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: mlswift_playground.py
90 lines (86 loc) · 4.02 KB
/
mlswift_playground.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from swiftclient.service import SwiftService, SwiftPostObject, SwiftError
import pickle
import torchvision
from application_layer.utils import get_model
from time import time
import torch
import argparse
# Command-line interface for the remote ML job.  Every knob has a sensible
# default, so the script runs with no arguments at all.
parser = argparse.ArgumentParser(description='Do ML computation in Swift')
# (flag, default, type, help) — registered in a single loop below.
_FLAGS = [
    ('--dataset',    'mnist',     str, 'dataset to be used'),
    ('--model',      'convnet',   str, 'model to be used'),
    ('--task',       'inference', str, 'ML task (inference or training)'),
    ('--batch_size', 100,         int, 'batch size for dataloader'),
    ('--num_epochs', 10,          int, 'number of epochs for training'),
    ('--freeze_idx', 1000,        int, 'the last layer of feature extraction; training starts from the subsequent layer'),
    ('--gpu_id',     0,           int, 'which GPU on the server to run on; added for load balancing'),
]
for _flag, _default, _type, _help in _FLAGS:
    parser.add_argument(_flag, default=_default, type=_type, help=_help)
args = parser.parse_args()
# Unpack into the module-level names the rest of the script uses.
dataset, model, task = args.dataset, args.model, args.task
batch_size, num_epochs = args.batch_size, args.num_epochs
freeze_idx, gpu_id = args.freeze_idx, args.gpu_id
print(args)
# Map each dataset to the directory it lives under inside the Swift container.
parent_dirs = {'imagenet':'val', 'mnist':'mnist', 'cifar10':'cifar-10-batches-py'}
parent_dir = parent_dirs[dataset]
# Object inside the container that the POST targets.  MNIST ships separate
# train/test archives, so its key is qualified by the task name.
objs_invoke = {'imagenet':'{}/ILSVRC2012_val_00000001.JPEG'.format(parent_dir),
               'mnist_training':'{}/train-images-idx3-ubyte'.format(parent_dir),
               'mnist_inference':'{}/t10k-images-idx3-ubyte'.format(parent_dir),
               'cifar10':'{}/test_batch'.format(parent_dir)}
try:
    obj = objs_invoke[dataset]
except KeyError:
    # Only mnist lacks a plain entry; fall back to the task-qualified key
    # (e.g. 'mnist_training').  Was a bare `except:`, which would also have
    # masked unrelated errors such as a NameError.
    obj = objs_invoke[dataset + "_" + task]
objects = [obj]
swift = SwiftService()
step = 10000
# A training task must issue a single POST, and the small datasets (mnist,
# cifar10) fit in one request anyway; collapsing dataset_size to `step`
# makes the request loop below execute exactly once.  Only ImageNet
# inference is sharded across multiple requests.
if task == 'training' or dataset in ('mnist', 'cifar10'):
    dataset_size = step
else:
    dataset_size = 50000  # Imagenet has 50K images for now
# Fan the work out to Swift: one POST per shard of `step` objects.  For
# training and the small datasets, dataset_size == step, so this loop body
# runs exactly once (see the sizing logic above).
start_time = time()
#post_objects = []
for start in range(0,dataset_size,step):
    end = start+step
    #ask to do inference for images [start:end] from the test batch
    print("{} for data [{}:{}]".format(task,start,end))
    # Job parameters travel as object metadata / headers on the POST.
    # NOTE(review): "meta" here is a *set* of pre-formatted "Name:Value"
    # strings, not a dict; swiftclient appears to accept any iterable of
    # such strings, but a list would be the conventional shape — confirm
    # against SwiftService.post's expectations.
    opts = {"meta": {"Ml-Task:{}".format(task),
                     "dataset:{}".format(dataset),"model:{}".format(model),
                     "Batch-Size:{}".format(batch_size),"Num-Epochs:{}".format(num_epochs),
                     "Lossfn:cross-entropy","Optimizer:sgd",
                     "start:{}".format(start),"end:{}".format(end),
                     "Split-Idx:{}".format(freeze_idx),
                     "GPU-ID:{}".format(gpu_id)
                     },
            "header": {"Parent-Dir:{}".format(parent_dir)}}
    post_objects = [SwiftPostObject(o,opts) for o in objects]
    # Issue the POST and consume the per-object results as they stream back.
    for post_res in swift.post(
            container=dataset,
            objects=post_objects):
        if post_res['success']:
            print("Object '%s' POST success" % post_res['object'])
            print("Request took {} seconds".format(time()-start_time))
            body = post_res['result']
            if task == 'inference':
                # SECURITY NOTE(review): pickle.loads on bytes received from
                # the Swift service — only safe if the service is fully
                # trusted; pickle can execute arbitrary code on load.
                inf_res = pickle.loads(body)
                # if model.startswith("my"): #new path for transfer learning toy example....this is not really inference
                #     model = get_model(model, dataset) #use CPU only in this script...no need for GPU now
                #     final_res = []
                #     for int_res in inf_res:
                #         inputs = torch.from_numpy(int_res)
                #         logits = model(inputs, split_idx,100) #continue the inference process here
                #         final_res.extend(logits.max(1)[1])
                #     print("Split inference results length: {}".format(len(final_res)))
                # else:
                print("{} result length: {}".format(task, len(inf_res)))
            elif task == 'training':
                # The service returns a trained state_dict; rebuild the model
                # locally and load it.
                model_dict = pickle.loads(body)
                # NOTE(review): this rebinds `model` from the model *name*
                # (str) to a model object; harmless because training runs
                # this loop exactly once, but it would break a second
                # iteration.
                model = get_model(model, dataset)
                model.load_state_dict(model_dict)
        else:
            print("Object '%s' POST failed" % post_res['object'])
print("The whole process took {} seconds".format(time()-start_time))