-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathredis_dataloader_test.py
63 lines (52 loc) · 1.66 KB
/
redis_dataloader_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import gc
import random
import pprint
from six.moves import range
from time import gmtime, strftime
from timeit import default_timer as timer
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import options
from misc.dataloader import VQAPoolDataset
import misc.utilities as utils
from misc.eval_questioner import evalQBot
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import numpy as np
import pdb
# Parse command-line options and derive the dataset configuration, forcing
# the Redis-backed loading path (useRedis = 1) for this run.
params = options.readCommandLine()
data_params = options.data_params(params)
data_params['useRedis'] = 1

# Seed Python's and torch's RNGs from the same CLI seed for reproducibility.
# NOTE: CUDA seeding (torch.cuda.manual_seed_all) is not performed here.
random.seed(params['randomSeed'])
torch.manual_seed(params['randomSeed'])

# Build the dataset over all three splits, then point it at 'train'
# (presumably selects which split the dataset serves — confirm in VQAPoolDataset).
splits = ['train', 'val1', 'val2']
dataset = VQAPoolDataset(data_params, splits)
dataset.split = 'train'

# Unshuffled loading through 3 worker processes, without pinned memory.
dataloader = DataLoader(
    dataset,
    batch_size=params['batchSize'],
    shuffle=False,
    num_workers=3,
    pin_memory=False,
)
def batch_iter(dataloader, num_epochs=None):
    """Yield every batch of ``dataloader`` across several epochs.

    Args:
        dataloader: A re-iterable of batches (e.g. a torch ``DataLoader``);
            it is iterated once from the start for every epoch.
        num_epochs: Number of passes over ``dataloader``. When ``None``
            (the default) the module-level ``params['numEpochs']`` is used,
            which preserves the original call signature ``batch_iter(dl)``.

    Yields:
        Tuples ``(epochId, idx, batch)``; ``idx`` restarts at 0 each epoch.
    """
    # Resolve the global lazily so callers passing num_epochs never need it.
    if num_epochs is None:
        num_epochs = params['numEpochs']
    for epochId in range(num_epochs):
        for idx, batch in enumerate(dataloader):
            yield epochId, idx, batch
# Only run the timing loop when executed as a script. The DataLoader uses
# worker processes (num_workers > 0); under the "spawn" start method each
# worker re-imports this file (as __mp_main__), and an unguarded loop would
# try to start new workers during bootstrapping and fail.
if __name__ == '__main__':
    numIterPerEpoch = len(dataloader)
    start_t = timer()
    for epochId, idx, batch in batch_iter(dataloader):
        # Global iteration index across epochs, and progress measured in epochs.
        iterId = idx + (epochId * numIterPerEpoch)
        epochFrac = iterId / numIterPerEpoch
        # Wall-clock time since the previous timer() call: covers waiting for
        # this batch plus formatting/printing the previous log line.
        end_t = timer()
        timeStamp = strftime('%a %d %b %y %X', gmtime())
        log_line = f'[{timeStamp}][Ep: {epochFrac:.2f}][Iter: {iterId}][Time: {end_t - start_t:5.2f}s]'
        print(log_line)
        start_t = end_t