From 87e4e25db176b4dced5fb2952a2e674e5b8b0db8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 18 Sep 2017 19:50:53 -0700 Subject: [PATCH] Change Transform API Using DeviceContext, not Place to get stream --- paddle/platform/CMakeLists.txt | 2 +- paddle/platform/transform.h | 31 +++++++++++++++++++------------ paddle/platform/transform_test.cu | 13 +++++++++---- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 8b605e51c3f4e..daf519b91d623 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -24,4 +24,4 @@ cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) -nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place) +nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context) diff --git a/paddle/platform/transform.h b/paddle/platform/transform.h index 3ee4acd29660f..8eaab047fd4da 100644 --- a/paddle/platform/transform.h +++ b/paddle/platform/transform.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" #include "paddle/platform/hostdevice.h" #include "paddle/platform/place.h" @@ -21,6 +22,7 @@ #include #include #ifdef __NVCC__ +#include #include #include "paddle/platform/details/device_ptr_cast.h" #endif @@ -28,34 +30,39 @@ namespace paddle { namespace platform { // Transform on host or device. It provides the same API in std library. -template -void Transform(Place place, InputIter first, InputIter last, OutputIter result, - UnaryOperation op) { +template +void Transform(const DeviceContext& context, InputIter first, InputIter last, + OutputIter result, UnaryOperation op) { + auto place = context.GetPlace(); if (is_cpu_place(place)) { std::transform(first, last, result, op); } else { #ifdef __NVCC__ + auto& ctx = reinterpret_cast(context); using namespace details; - thrust::transform(DevPtrCast(first), DevPtrCast(last), DevPtrCast(result), - op); + thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first), + DevPtrCast(last), DevPtrCast(result), op); #else PADDLE_THROW("Do not invoke `Transform` in .cc file"); #endif } } -template -void Transform(Place place, InputIter1 first1, InputIter1 last1, - InputIter2 first2, OutputIter result, BinaryOperation op) { +template +void Transform(const DeviceContext& context, InputIter1 first1, + InputIter1 last1, InputIter2 first2, OutputIter result, + BinaryOperation op) { + auto place = context.GetPlace(); if (is_cpu_place(place)) { std::transform(first1, last1, first2, result, op); } else { #ifdef __NVCC__ + auto& ctx = reinterpret_cast(context); using namespace details; - thrust::transform(DevPtrCast(first1), DevPtrCast(last1), DevPtrCast(first2), - DevPtrCast(result), op); + thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first1), + DevPtrCast(last1), DevPtrCast(first2), DevPtrCast(result), + op); #else PADDLE_THROW("Do not invoke `Transform` in .cc file"); #endif diff --git a/paddle/platform/transform_test.cu b/paddle/platform/transform_test.cu index 600fed8f45077..b8a6200bb03c9 100644 --- a/paddle/platform/transform_test.cu +++ b/paddle/platform/transform_test.cu @@ -36,8 +36,9 @@ class Multiply { TEST(Transform, CPUUnary) { using namespace paddle::platform; + CPUDeviceContext ctx; float buf[4] = {0.1, 0.2, 0.3, 0.4}; - Transform(CPUPlace(), buf, buf + 4, buf, Scale(10)); + Transform(ctx, buf, buf + 4, buf, Scale(10)); for (int i = 0; i < 4; ++i) { ASSERT_NEAR(buf[i], static_cast(i + 1), 1e-5); } @@ -47,10 +48,12 @@ TEST(Transform, GPUUnary) { using namespace paddle::platform; using namespace paddle::memory; GPUPlace gpu0(0); + CUDADeviceContext ctx(gpu0); float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4}; float* gpu_buf = static_cast(Alloc(gpu0, sizeof(float) * 4)); Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf)); - Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, Scale(10)); + Transform(ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale(10)); + ctx.Wait(); Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf)); Free(gpu0, gpu_buf); for (int i = 0; i < 4; ++i) { @@ -62,7 +65,7 @@ TEST(Transform, CPUBinary) { using namespace paddle::platform; using namespace paddle::memory; int buf[4] = {1, 2, 3, 4}; - Transform(CPUPlace(), buf, buf + 4, buf, buf, Multiply()); + Transform(CPUDeviceContext(), buf, buf + 4, buf, buf, Multiply()); for (int i = 0; i < 4; ++i) { ASSERT_EQ((i + 1) * (i + 1), buf[i]); } @@ -73,9 +76,11 @@ TEST(Transform, GPUBinary) { using namespace paddle::memory; int buf[4] = {1, 2, 3, 4}; GPUPlace gpu0(0); + CUDADeviceContext ctx(gpu0); int* gpu_buf = static_cast(Alloc(gpu0, sizeof(buf))); Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf)); - Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply()); + Transform(ctx, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply()); + ctx.Wait(); Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf)); Free(gpu0, gpu_buf); for (int i = 0; i < 4; ++i) {