forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathRepeat.h
23 lines (19 loc) · 896 Bytes
/
Repeat.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#pragma once
#include <ATen/ATen.h>
namespace at { namespace native {

// Shared driver for repeat_interleave: validates `repeats`, allocates the
// output tensor and delegates the actual fill to the `compute` template
// argument.
//
// Template parameter:
//   compute(repeat_ptr, cumsum_ptr, result_ptr, size) receives raw int64_t
//     pointers to the contiguous `repeats` data, to its inclusive prefix sum,
//     and to the (uninitialized) output buffer of length sum(repeats), plus
//     the number of entries in `repeats`.
//     NOTE(review): presumably it writes index i into the output repeats[i]
//     times; confirm against the implementations that instantiate this.
//
// @param repeats 1-D Long tensor of non-negative repeat counts.
// @return 1-D tensor with repeats.options() whose length is sum(repeats);
//         an empty tensor when `repeats` has no elements.
template <void compute(int64_t *, int64_t *, int64_t *, int64_t)>
static inline Tensor repeat_interleave_common(const Tensor &repeats) {
  AT_CHECK(repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
  AT_CHECK(repeats.scalar_type() == at::kLong, "repeats has to be Long tensor");
  AT_CHECK((repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
  // Nothing to fill. Also, cumsum[-1] below would index out of bounds on the
  // empty prefix-sum tensor.
  if (repeats.size(0) == 0) {
    return at::empty({0}, repeats.options());
  }
  // compute() walks raw pointers, so it needs densely packed input.
  Tensor repeats_ = repeats.contiguous();
  Tensor cumsum = repeats_.cumsum(0);
  // The last inclusive prefix sum is the total output length.
  int64_t total = cumsum[-1].item<int64_t>();
  Tensor result = at::empty({total}, repeats.options());
  int64_t *repeat_ptr = repeats_.data<int64_t>();
  int64_t *cumsum_ptr = cumsum.data<int64_t>();
  int64_t *result_ptr = result.data<int64_t>();
  compute(repeat_ptr, cumsum_ptr, result_ptr, repeats_.size(0));
  return result;
}

}} // namespace at::native