dynamicarray.py
import torch

class DynArr:
    def __init__(self, d=None, resizemult=2, initsize=0, initbuffersize=10, device='cpu', **kw):
        '''
        DynArr implements the "expanding" arrays in the Pi-Net.
        Each training step uses gradient *concatenation* instead of accumulation,
        so all matrices grow over time; DynArr is a simple buffer that makes
        this cheap by resizing geometrically. Note that torch is really not a
        fan of this whole process, hence the custom backward and resizing logic.
        '''
self.d = d
self.device = device
self.resizemult = resizemult
if d is not None:
self.arr = torch.zeros([initbuffersize, d], device=device)
else:
self.arr = torch.zeros([initbuffersize], device=device)
        self.size = initsize
        # Default so restore() before any checkpoint() rolls back to the initial size.
        self._checkpoint = initsize
def isempty(self):
return self.size == 0
    def _cat(self, arr):
        size = arr.shape[0]
        # Check the incoming shape up front, even when no resize is needed.
        if self.d is not None:
            assert arr.shape[1] == self.d
        else:
            assert len(arr.shape) == 1
        if self.size + size > len(self.arr):
            # Grow geometrically so repeated appends stay amortized O(1).
            newsize = int(self.resizemult * self.size + size)
            if self.d is not None:
                # resize_ preserves the existing rows of this contiguous,
                # row-major buffer; the new tail is uninitialized until written.
                self.arr.resize_(newsize, self.d)
            else:
                self.arr.resize_(newsize)
        self.arr[self.size:self.size+size] = arr
        self.size += size
def cat(self, *arrs):
for arr in arrs:
self._cat(arr)
def checkpoint(self):
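        # Remember the current logical length so restore() can roll back to it.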
self._checkpoint = self.size
def restore(self):
self.size = self._checkpoint
@property
def a(self):
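        # View of only the filled prefix of the buffer.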
return self.arr[:self.size]
def cuda(self):
self.arr = self.arr.cuda()
return self
def cpu(self):
self.arr = self.arr.cpu()
return self
def half(self):
self.arr = self.arr.half()
return self
def float(self):
self.arr = self.arr.float()
return self
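
# A minimal usage sketch (not from the original file; shapes and sizes are
# illustrative): append a few chunks, read back the live view via `a`, and
# roll back with checkpoint()/restore().
#
#   buf = DynArr(d=4, resizemult=2, initbuffersize=8)
#   buf.cat(torch.randn(3, 4), torch.randn(5, 4))
#   assert buf.a.shape == (8, 4)   # 3 + 5 live rows
#   buf.checkpoint()
#   buf.cat(torch.randn(2, 4))     # exceeds the buffer, triggers a resize
#   buf.restore()                  # logical rollback; capacity is kept
#   assert buf.a.shape == (8, 4)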

class CycArr:
    def __init__(self, d=None, maxsize=10000, initsize=0, device='cpu', **kw):
        '''
        Fixed-capacity variant used for some testing (deprecated).
        Instead of appending to the end of a growing buffer, writes wrap
        around cyclically once maxsize is reached, overwriting the oldest data.
        '''
assert initsize <= maxsize
self.size = initsize
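        # `end` is the next write position once the buffer is full;
        # while still filling, it is unused and left as None.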
if initsize == maxsize:
self.end = 0
else:
self.end = None
self.d = d
self.maxsize = maxsize
self.device = device
if d is not None:
self.arr = torch.zeros([maxsize, d], device=device)
else:
self.arr = torch.zeros([maxsize], device=device)
@property
def a(self):
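        # Full view once at capacity (rows are in buffer order, not insertion
        # order); otherwise only the filled prefix.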
if self.size == self.maxsize:
return self.arr
return self.arr[:self.size]
def cuda(self):
self.arr = self.arr.cuda()
return self
def half(self):
self.arr = self.arr.half()
return self
def isempty(self):
return self.size == 0
    def _cat(self, arr):
        size = arr.shape[0]
        # A single write may not exceed the buffer capacity.
        assert size <= self.maxsize
        if self.size == self.maxsize:
            # Buffer already full: overwrite cyclically starting at `end`.
            if self.end + size < self.maxsize:
                self.arr[self.end:self.end+size] = arr
                self.end += size
            else:
                # Split the write: fill the tail, then wrap to the front.
                p1size = self.maxsize - self.end
                p2size = size - p1size
                self.arr[self.end:] = arr[:p1size]
                self.arr[:p2size] = arr[p1size:]
                self.end = p2size
        elif self.size + size >= self.maxsize:
            # This write fills the buffer and spills over to the beginning.
            p1size = self.maxsize - self.size
            p2size = size - p1size
            self.arr[self.size:] = arr[:p1size]
            self.arr[:p2size] = arr[p1size:]
            self.end = p2size
            self.size = self.maxsize
        else:
            # Still filling up: plain append.
            self.arr[self.size:self.size+size] = arr
            self.size += size
def cat(self, *arrs):
for arr in arrs:
self._cat(arr)
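
# A minimal smoke test (added here as an illustrative sketch, not part of the
# original file): exercises both buffers end to end with small assumed sizes.
if __name__ == '__main__':
    dyn = DynArr(d=3, initbuffersize=4)
    dyn.cat(torch.ones(2, 3), torch.zeros(4, 3))  # second append forces a resize
    assert dyn.a.shape == (6, 3)

    cyc = CycArr(d=3, maxsize=5)
    cyc.cat(torch.ones(4, 3))    # plain append while filling
    cyc.cat(torch.zeros(3, 3))   # fills the buffer and wraps to the front
    assert cyc.a.shape == (5, 3) and cyc.end == 2
    print('dynamicarray smoke test passed')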