forked from gorgonia/tensor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdefaultengine_matop_misc.go
131 lines (110 loc) · 3.06 KB
/
defaultengine_matop_misc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package tensor
import "github.com/pkg/errors"
// Repeat repeats the elements of t along axis, with repeats giving the repeat
// counts (their exact interpretation is delegated to Shape.Repeat).
//
// Only DenseTensor inputs are supported; the work is performed by
// denseRepeat. Any other Tensor implementation yields an "NYI" error.
func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) {
	dt, isDense := t.(DenseTensor)
	if !isDense {
		return nil, errors.Errorf("NYI")
	}
	return e.denseRepeat(dt, axis, repeats)
}
// denseRepeat implements Repeat for DenseTensor inputs and returns a freshly
// obtained *Dense (from recycledDense) holding the repeated data.
//
// t.Shape().Repeat supplies three things: the output shape, the normalised
// repeats slice, and size, which is the number of repeats entries consumed
// per outer iteration. An axis of AllAxes is treated as axis 0 for the copy
// loop.
func (StdEng) denseRepeat(t DenseTensor, axis int, repeats []int) (retVal DenseTensor, err error) {
	var (
		newShape Shape
		size     int
	)
	if newShape, repeats, size, err = t.Shape().Repeat(axis, repeats...); err != nil {
		return nil, errors.Wrap(err, "Unable to get repeated shape")
	}
	if axis == AllAxes {
		axis = 0
	}

	d := recycledDense(t.Dtype(), newShape)

	// outers is the product of the dimensions in front of axis. It falls back
	// to 1 for scalars, and also when that product comes out as 0.
	outers := 1
	if !t.IsScalar() {
		if prod := ProdInts(t.Shape()[0:axis]); prod != 0 {
			outers = prod
		}
	}

	// Vectors are special-cased because CalcStrides() returns []int{1} as the
	// strides for a vector, so ostrides()[axis] cannot be trusted for them.
	srcStride, dstStride := 1, 1
	if !newShape.IsVector() && !t.IsVector() {
		srcStride = t.ostrides()[axis]
	}
	if !newShape.IsVector() {
		dstStride = d.ostrides()[axis]
	}

	// Each copy writes the whole tail of t (from src) into the tail of d (from
	// dest). Later copies overwrite everything from the next dest onwards.
	// NOTE(review): assuming copyDenseSliced copies t[src:t.len()] into
	// d[dest:d.len()], only the leading dstStride elements of each write
	// survive. That over-copy is load-bearing when t is a vector (srcStride
	// forced to 1) but dstStride is larger, so it must not be narrowed blindly.
	dest, src := 0, 0
	for o := 0; o < outers; o++ {
		for j := 0; j < size; j++ {
			for n := repeats[j]; n > 0; n-- {
				if src >= t.len() || dest+srcStride > d.len() {
					break
				}
				copyDenseSliced(d, dest, d.len(), t, src, t.len())
				dest += dstStride
			}
			src += srcStride
		}
	}
	return d, nil
}
// Concat joins t with others along axis and returns the combined tensor.
//
// Only a DenseTensor t is supported. The others are first converted with
// tensorsToDenseTensors, and that conversion's error is wrapped as
// "Concat failed". The concatenation itself is delegated to denseConcat. Any
// other kind of t yields an "NYI" error.
func (e StdEng) Concat(t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) {
	first, isDense := t.(DenseTensor)
	if !isDense {
		return nil, errors.Errorf("NYI")
	}

	var rest []DenseTensor
	if rest, err = tensorsToDenseTensors(others); err != nil {
		return nil, errors.Wrap(err, "Concat failed")
	}
	return e.denseConcat(first, axis, rest)
}
// denseConcat concatenates a followed by every tensor in Ts along axis. The
// result is a freshly obtained *Dense (from recycledDense).
//
// The output shape is computed by a.Shape().Concat(axis, ...). Each operand is
// then written, in order (a first, then Ts[0], Ts[1], ...), into its own
// [start, end) range along axis of the result via assignArray.
//
// The result is given a mask (makeMask) when ANY operand, including a,
// implements MaskedTensor and reports IsMasked().
func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTensor, error) {
	// Shapes of the non-receiver operands; a's shape is the receiver of Concat.
	ss := make([]Shape, len(Ts))
	for i, T := range Ts {
		ss[i] = T.Shape()
	}

	newShape, err := a.Shape().Concat(axis, ss...)
	if err != nil {
		return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation")
	}

	// all holds every operand in output order.
	all := make([]DenseTensor, 0, len(Ts)+1)
	all = append(all, a)
	all = append(all, Ts...)

	// Inspect every operand, not just Ts. Otherwise a masked first operand
	// would silently produce an unmasked result, while a masked Ts[i] would not.
	var isMasked bool
	for _, T := range all {
		if mt, ok := T.(MaskedTensor); ok && mt.IsMasked() {
			isMasked = true
			break
		}
	}

	retVal := recycledDense(a.Dtype(), newShape)
	if isMasked {
		retVal.makeMask()
	}

	var start, end int
	for _, T := range all {
		end += T.Shape()[axis]

		// Only the entry at axis is set. Entries for the leading axes stay nil,
		// presumably meaning "take the whole axis" in sliceDense — TODO confirm.
		ranges := make([]Slice, axis+1)
		ranges[axis] = makeRS(start, end)

		var v *Dense
		if v, err = sliceDense(retVal, ranges...); err != nil {
			return nil, errors.Wrap(err, "Unable to slice DenseTensor while performing denseConcat")
		}
		// NOTE(review): when concatenating a matrix along axis 0, the slice
		// seemingly comes back as a vector. It is reshaped to an (n, 1) column
		// so that its dimensions line up with T for assignArray.
		if v.IsVector() && T.IsMatrix() && axis == 0 {
			v.reshape(v.shape[0], 1)
		}
		if err = assignArray(v, T); err != nil {
			return nil, errors.Wrap(err, "Unable to assignArray in denseConcat")
		}
		start = end
	}
	return retVal, nil
}