Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix top k op GPU code #5221

Merged
merged 4 commits into from
Oct 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions paddle/operators/top_k_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ using Tensor = framework::Tensor;
template <typename T>
struct Pair {
__device__ __forceinline__ Pair() {}
__device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
__device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}

__device__ __forceinline__ void set(T value, int id) {
__device__ __forceinline__ void set(T value, int64_t id) {
v = value;
id = id;
}
Expand All @@ -48,7 +48,7 @@ struct Pair {
}

T v;
int id;
int64_t id;
};

template <typename T>
Expand Down Expand Up @@ -197,7 +197,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
Pair<T> topk[], T** topVal,
int** topIds, int& beam, int& k,
int64_t** topIds, int& beam, int& k,
const int tid, const int warp) {
while (true) {
__syncthreads();
Expand Down Expand Up @@ -249,7 +249,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
* 4. go to the first setp, until get the topk value.
*/
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
const T* src, int lds, int dim, int k) {
__shared__ Pair<T> sh_topk[BlockSize];
__shared__ int maxid[BlockSize / 2];
Expand Down Expand Up @@ -293,7 +293,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {

T* output_data = output->mutable_data<T>(ctx.GetPlace());
// FIXME(typhoonzero): data is always converted to type T?
int* indices_data = indices->mutable_data<int>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());

size_t input_height = input->dims()[0];
size_t input_width = input->dims()[1];
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/top_k_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TopkKernel : public framework::OpKernel<T> {
const size_t k = static_cast<int>(ctx.Attr<int>("k"));

T* output_data = output->mutable_data<T>(ctx.GetPlace());
T* indices_data = indices->mutable_data<T>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());

auto eg_input = EigenMatrix<T>::From(*input);

Expand All @@ -66,7 +66,7 @@ class TopkKernel : public framework::OpKernel<T> {
});
for (size_t j = 0; j < k; j++) {
output_data[i * k + j] = vec[j].first;
indices_data[i * k + j] = vec[j].second;
indices_data[i * k + j] = int64_t(vec[j].second);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/v2/framework/tests/test_top_k_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def setUp(self):
k = 1
input = np.random.random((32, 84)).astype("float32")
output = np.ndarray((32, k))
indices = np.ndarray((32, k))
indices = np.ndarray((32, k)).astype("int64")

self.inputs = {'X': input}
self.attrs = {'k': k}
Expand All @@ -32,7 +32,7 @@ def setUp(self):
input = np.random.random((32, 2, 84)).astype("float32")
input_flat_2d = input.reshape(64, 84)
output = np.ndarray((64, k))
indices = np.ndarray((64, k)).astype("int")
indices = np.ndarray((64, k)).astype("int64")

# FIXME: should use 'X': input for a 3d input
self.inputs = {'X': input_flat_2d}
Expand Down