-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: data_process.py
44 lines (34 loc) · 1.36 KB
/
data_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
import torch
from torch.autograd import Variable
def make_variable(tensor, volatile=False, requires_grad=False):
    """Move ``tensor`` to the GPU (when available) and configure autograd.

    ``Variable`` was merged into ``Tensor`` in PyTorch 0.4 and the
    ``volatile`` flag was removed (use ``with torch.no_grad():`` instead),
    so the old ``Variable(tensor, volatile=...)`` call no longer works.

    Args:
        tensor: input ``torch.Tensor``.
        volatile: deprecated; kept only for backward compatibility with old
            call sites and ignored.
        requires_grad: if True, autograd records operations on the returned
            tensor.

    Returns:
        The tensor, moved to CUDA when a GPU is available, with
        ``requires_grad`` enabled when requested.
    """
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    if requires_grad:
        # In-place flag update; only called when enabling, so tensors that
        # already track gradients (or plain inference tensors) pass through.
        tensor = tensor.requires_grad_(True)
    return tensor
def shuffle_aligned_list(data):
    """Apply one shared random permutation to every array in ``data``.

    Every array must have the same length along axis 0; because a single
    permutation is reused, corresponding rows stay aligned across arrays.

    Args:
        data: list of array-likes supporting fancy indexing.

    Returns:
        A new list with each array reordered by the same random permutation.
    """
    order = np.random.permutation(data[0].shape[0])
    shuffled = []
    for arr in data:
        shuffled.append(arr[order])
    return shuffled
def batch_generator(data, batch_size, shuffle=False, wrap_around=False):
    """Yield aligned batches from a list of array-like objects.

    Given a list of array-likes (all with the same first-dimension length),
    yield lists containing the same ``[start:end)`` slice of each input.
    A trailing batch shorter than ``batch_size`` is dropped, except when the
    data holds fewer than ``batch_size`` rows in total (then the single short
    batch is yielded, matching the original behavior).

    Args:
        data: list of array-likes sharing the same length along axis 0.
        batch_size: number of rows per batch.
        shuffle: if True, apply one shared permutation to all arrays before
            each pass (re-shuffled every epoch when wrapping around).
        wrap_around: if True, cycle over the data forever; otherwise stop
            after the last full batch.

    Yields:
        list of slices, one per input array, all covering the same rows.
    """
    if shuffle:
        data = shuffle_aligned_list(data)
    num = len(data[0])
    batch_count = 0
    while True:
        # Reset (and optionally reshuffle) once the next full batch would run
        # past the end of the data.  Using '>' keeps the final batch when
        # batch_size divides num exactly; the original '>=' skipped it.
        if batch_count * batch_size + batch_size > num:
            batch_count = 0
            if shuffle:
                data = shuffle_aligned_list(data)
        start = batch_count * batch_size
        end = start + batch_size
        batch_count += 1
        yield [d[start:end] for d in data]
        # Stop after the last full batch unless the caller wants to cycle.
        if batch_count * batch_size + batch_size > num and not wrap_around:
            break