-
Notifications
You must be signed in to change notification settings - Fork 1
/
batch.py
33 lines (25 loc) · 1.03 KB
/
batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy
import numpy.random
from itertools import chain
class Batch:
    """A collection of items that can be sliced, split, sampled and repeated.

    Attributes:
        size: number of items currently held; kept equal to ``len(buffers)``.
        buffers: the items themselves. This is the ``arrays`` object given to
            the constructor, a slice of it, the NumPy array produced by
            ``sample``, or the list produced by ``expand``.
        input_size, output_size: carried unchanged into every derived batch.
            This class never reads them. They are presumably model
            input/output dimensions; confirm with callers.
    """

    def __init__(self, arrays, input_size, output_size):
        """Wrap ``arrays`` (any sized, sliceable sequence) without copying it."""
        self.size = len(arrays)
        self.buffers = arrays
        self.input_size = input_size
        self.output_size = output_size

    def sub_batch(self, from_, to):
        """Return a new Batch holding ``buffers[from_:to]`` (ordinary slice rules)."""
        return Batch(self.buffers[from_:to], self.input_size, self.output_size)

    def split(self, parts):
        """Split into ``parts`` consecutive, equally sized sub-batches.

        Each part holds ``len(self) // parts`` items. The trailing
        ``len(self) % parts`` items are dropped so that all parts match.
        A negative ``parts`` yields an empty list.

        Raises:
            ZeroDivisionError: if ``parts`` is 0.
        """
        part_size = self.size // parts
        # `range` instead of the Python-2-only `xrange`, which raised
        # NameError on Python 3; the iteration is identical.
        return [self.sub_batch(i * part_size, (i + 1) * part_size)
                for i in range(parts)]

    def sample(self, count):
        """Return a new Batch of ``min(count, len(self))`` distinct items.

        Items are picked through ``numpy.random.permutation``, so the result
        depends on NumPy's global RNG state. The new batch's ``buffers`` is
        the ndarray returned by ``numpy.take`` along axis 0, not the original
        container type.
        """
        count = min(count, self.size)
        order = numpy.random.permutation(self.size)
        # NOTE(review): numpy.take first converts buffers to an ndarray, which
        # assumes the items stack into a regular array; confirm for ragged data.
        picked = numpy.take(self.buffers, order[:count], axis=0)
        return Batch(picked, self.input_size, self.output_size)

    def sample_split(self, samples, parts):
        """Randomly sample ``samples`` items, then split them into ``parts``."""
        return self.sample(samples).split(parts)

    def __len__(self):
        """Return the number of items held."""
        return self.size

    def expand(self, multiplier):
        """Repeat the contents in place ``multiplier`` times.

        ``buffers`` becomes a list of the items concatenated ``multiplier``
        times. A multiplier of zero or less empties the batch.
        """
        self.buffers = list(chain.from_iterable([self.buffers] * multiplier))
        # Recompute from the real contents. The old `size *= multiplier` went
        # negative for a negative multiplier even though the list was empty.
        self.size = len(self.buffers)