forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ExpandUtils.cpp
212 lines (185 loc) · 7 KB
/
ExpandUtils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#include <ATen/ExpandUtils.h>
#include <c10/util/irange.h>
namespace at {
namespace {
// Computes the broadcast shape of `a` and `b`.
//
// The two shapes are aligned at their trailing dimensions; a dimension that
// is missing from the shorter shape is treated as having size 1. For every
// aligned pair the sizes must be equal, or at least one of them must be 1,
// otherwise the TORCH_CHECK below fails. The result has
// max(a.size(), b.size()) entries and is returned in a `Container`, which
// must be constructible from an element count and indexable with operator[].
//
// NOTE: are_expandable performs a similar check; please keep the two in sync
// if a change is needed.
template <typename Container>
Container infer_size_impl(IntArrayRef a, IntArrayRef b) {
  // Do all index arithmetic in ptrdiff_t so that "this dimension does not
  // exist" is a genuinely negative index, instead of relying on unsigned
  // wraparound followed by an unsigned-to-signed conversion.
  const auto dimsA = static_cast<ptrdiff_t>(a.size());
  const auto dimsB = static_cast<ptrdiff_t>(b.size());
  const ptrdiff_t ndim = dimsA > dimsB ? dimsA : dimsB;
  Container expandedSizes(static_cast<size_t>(ndim));

  for (ptrdiff_t i = ndim - 1; i >= 0; --i) {
    // Distance of output dimension `i` from the trailing end.
    const ptrdiff_t offset = ndim - 1 - i;
    // Matching dimension in each input; negative when the input is too short.
    const ptrdiff_t dimA = dimsA - 1 - offset;
    const ptrdiff_t dimB = dimsB - 1 - offset;
    const int64_t sizeA = (dimA >= 0) ? a[dimA] : 1;
    const int64_t sizeB = (dimB >= 0) ? b[dimB] : 1;

    TORCH_CHECK(
        sizeA == sizeB || sizeA == 1 || sizeB == 1,
        "The size of tensor a (", sizeA,
        ") must match the size of tensor b (", sizeB,
        ") at non-singleton dimension ", i);

    // 1s map to the other size (even 0).
    expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
  }

  return expandedSizes;
}
}
// Returns the broadcast shape of `a` and `b` as a std::vector<int64_t>.
// All of the work (including the compatibility check) is done by
// infer_size_impl.
std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b) {
  using Shape = std::vector<int64_t>;
  return infer_size_impl<Shape>(a, b);
}
// Same as infer_size, but the broadcast shape is returned in a DimVector
// instead of a std::vector<int64_t>.
DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b) {
  DimVector broadcast_shape = infer_size_impl<DimVector>(a, b);
  return broadcast_shape;
}
// Computes the sizes and strides obtained by expanding a tensor described by
// `tensor_sizes` / `tensor_strides` to the requested `sizes`.
//
// Dimensions are aligned at the trailing end. For each target dimension:
//   * a requested size of -1 means "keep the existing size" and is only
//     accepted where the source tensor actually has a dimension;
//   * if the requested size differs from the existing size, the existing
//     size must be 1, and the expanded dimension gets stride 0;
//   * a leading dimension that does not exist in the source starts out with
//     size 1 and a stride derived from the next (already computed) dimension.
//
// A source with zero dimensions is handled up front by constructing the
// result directly from `sizes` and the target dimension count.
template<typename Container>
C10_ALWAYS_INLINE InferExpandGeometryResult<Container> inferExpandGeometryImpl(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  int64_t ndim = sizes.size();
  int64_t tensor_dim = tensor_sizes.size();

  if (tensor_dim == 0) {
    return InferExpandGeometryResult<Container>(sizes, ndim);
  }

  InferExpandGeometryResult<Container> result(ndim);
  auto& outSizes = result.sizes;
  auto& outStrides = result.strides;

  // Adding `lead` to a target dimension index yields the matching source
  // dimension index; the sum is negative for dimensions the source lacks.
  const int64_t lead = tensor_dim - ndim;

  // Walk from the innermost dimension outwards so that entry i + 1 is
  // already filled in when a new leading dimension needs it.
  for (int64_t i = ndim; i-- > 0;) {
    const int64_t dim = i + lead;

    int64_t size = 1;
    int64_t stride = 0;
    if (dim >= 0) {
      size = tensor_sizes[dim];
      stride = tensor_strides[dim];
    } else {
      stride = outSizes[i + 1] * outStrides[i + 1];
    }

    int64_t requested = sizes[i];
    if (requested == -1) {
      TORCH_CHECK(
          dim >= 0,
          "The expanded size of the tensor (",
          requested,
          ") isn't allowed in a leading, non-existing dimension ",
          i);
      requested = size;
    }

    if (size != requested) {
      TORCH_CHECK(
          size == 1,
          "The expanded size of the tensor (",
          requested,
          ") must match the existing size (",
          size,
          ") at non-singleton dimension ",
          i,
          ". Target sizes: ",
          sizes,
          ". Tensor sizes: ",
          tensor_sizes);
      // Expanding a size-1 dimension: every index maps to the same element.
      size = requested;
      stride = 0;
    }

    outSizes[i] = size;
    outStrides[i] = stride;
  }

  return result;
}
// Expands the geometry given by `tensor_sizes` / `tensor_strides` to `sizes`
// and returns the resulting (sizes, strides) pair as a tuple of
// std::vector<int64_t>. The computation itself lives in
// inferExpandGeometryImpl.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  using Vec = std::vector<int64_t>;
  InferExpandGeometryResult<Vec> geometry =
      inferExpandGeometryImpl<Vec>(tensor_sizes, tensor_strides, sizes);
  // Move both vectors out of the temporary result to avoid copying them.
  return std::tuple<Vec, Vec>(
      std::move(geometry.sizes), std::move(geometry.strides));
}
// Same as inferExpandGeometry, but the expanded sizes and strides are
// returned as an InferExpandGeometryResult backed by DimVector.
InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  InferExpandGeometryResult<DimVector> geometry =
      inferExpandGeometryImpl<DimVector>(tensor_sizes, tensor_strides, sizes);
  return geometry;
}
// Returns dense, non-overlapping strides computed from `tensor_sizes` that
// keep the same layout permutation as the input `tensor_strides`.
// Note:
// 1. This function is meant for inputs that are non-dense or overlapping.
//    If the inputs are already dense and non-overlapping, the output strides
//    will be the same as `tensor_strides`. The function does not check for
//    that, though: the whole computation still runs, which is wasted work.
//
//    Where possible, verify that the inputs are non-dense or overlapping
//    before calling this function. If the inputs come from a tensor, this can
//    be checked through `is_non_overlapping_and_dense()`.
//
// 2. The stride propagation rule used here is exactly the same as the one
//    used in TensorIterator. Please refer to
//    https://github.com/pytorch/pytorch/pull/42922 for more details.
std::vector<int64_t> infer_dense_strides(IntArrayRef tensor_sizes, IntArrayRef tensor_strides) {
  TORCH_CHECK(tensor_sizes.size() == tensor_strides.size(),
    "Input sizes and strides should have same size but got ", tensor_sizes.size(), " and ", tensor_strides.size());

  size_t ndim = tensor_sizes.size();
  // Zero dimensions -> no strides; one dimension -> the single stride is 1.
  if (ndim < 2) {
    return std::vector<int64_t>(ndim, 1);
  }

  // `perm` starts out as n-1, n-2, ..., 1, 0.
  std::vector<int64_t> perm(ndim);
  std::iota(perm.rbegin(), perm.rend(), 0);

  // Three-way comparison of two dimensions. It mirrors TensorIterator so that
  // stride propagation behaves the same everywhere.
  //   -1: `lhs` should come before `rhs`
  //    1: `lhs` should come after `rhs`
  //    0: the comparison is ambiguous
  const auto compare_dims = [&](size_t lhs, size_t rhs) -> int {
    const int64_t lhs_stride = tensor_strides[lhs];
    const int64_t rhs_stride = tensor_strides[rhs];

    // Any 0 stride makes the comparison ambiguous (same as TensorIterator).
    if (lhs_stride == 0 || rhs_stride == 0) {
      return 0;
    }
    if (lhs_stride != rhs_stride) {
      return lhs_stride < rhs_stride ? -1 : 1;
    }
    // Equal strides: the dimension with the smaller size goes in front.
    return tensor_sizes[lhs] > tensor_sizes[rhs] ? 1 : 0;
  };

  // Stable insertion sort of the indices in `perm` by the input's strides and
  // shape. Dimensions with stride 0 never move, as in TensorIterator.
  // E.g. for size/stride (6, 5, 4, 3, 2)/(6, 0, 120, 0, 1) the initial `perm`
  // is (4, 3, 2, 1, 0) and the sorted `perm` is (4, 3, 0, 1, 2).
  for (const auto i : c10::irange(1, ndim)) {
    // Current position of the entry that is being inserted.
    size_t moving = i;
    // Scan the already-processed prefix from position i - 1 down to 0.
    for (size_t candidate = i; candidate-- > 0;) {
      const int order = compare_dims(perm[candidate], perm[moving]);
      if (order < 0) {
        break;
      }
      if (order > 0) {
        std::swap(perm[candidate], perm[moving]);
        moving = candidate;
      }
      // order == 0: ambiguous, leave both in place and keep scanning.
    }
  }

  // Assign contiguous strides in `perm` order, which preserves the input
  // tensor's memory layout.
  std::vector<int64_t> out_strides(ndim);
  int64_t running_stride = 1;
  for (const int64_t idx : perm) {
    out_strides[idx] = running_stride;
    // Note: a size of 0 is simply treated like 1. It does not matter here
    // since the total number of elements is 0 in that case.
    if (tensor_sizes[idx] > 1) {
      running_stride *= tensor_sizes[idx];
    }
  }
  return out_strides;
}
} // namespace at