From 1daa53eebf57525155c96b1f7a487db6263ba31f Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 11 Oct 2021 21:06:31 -0700 Subject: [PATCH] Update for `FutureValue` in `DeviceScan` API (NVIDIA/cub#305) --- dependencies/cub | 2 +- thrust/system/cuda/detail/async/exclusive_scan.h | 11 +++++++---- thrust/system/cuda/detail/scan.h | 9 +++++---- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/dependencies/cub b/dependencies/cub index 703b10a92..5712619b4 160000 --- a/dependencies/cub +++ b/dependencies/cub @@ -1 +1 @@ -Subproject commit 703b10a92ad5e21c65446108696d4e3d428d5f04 +Subproject commit 5712619b4e7ef3fc62c98d328f4ffeb390adf6f8 diff --git a/thrust/system/cuda/detail/async/exclusive_scan.h b/thrust/system/cuda/detail/async/exclusive_scan.h index 8735f7419..377285411 100644 --- a/thrust/system/cuda/detail/async/exclusive_scan.h +++ b/thrust/system/cuda/detail/async/exclusive_scan.h @@ -74,17 +74,20 @@ async_exclusive_scan_n(execution_policy& policy, InitialValueType init, BinaryOp op) { + using InputValueT = cub::detail::InputValue; using Dispatch32 = cub::DispatchScan; using Dispatch64 = cub::DispatchScan; + InputValueT init_value(init); + auto const device_alloc = get_async_device_allocator(policy); unique_eager_event ev; @@ -101,7 +104,7 @@ async_exclusive_scan_n(execution_policy& policy, first, out, op, - init, + init_value, n_fixed, nullptr, THRUST_DEBUG_SYNC_FLAG)); @@ -148,7 +151,7 @@ async_exclusive_scan_n(execution_policy& policy, first, out, op, - init, + init_value, n_fixed, user_raw_stream, THRUST_DEBUG_SYNC_FLAG)); diff --git a/thrust/system/cuda/detail/scan.h b/thrust/system/cuda/detail/scan.h index 4f9628319..6e266a8db 100644 --- a/thrust/system/cuda/detail/scan.h +++ b/thrust/system/cuda/detail/scan.h @@ -137,15 +137,16 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy &poli InitValueT init, ScanOp scan_op) { + using InputValueT = cub::detail::InputValue; using Dispatch32 = cub::DispatchScan; using Dispatch64 = cub::DispatchScan; cudaStream_t stream = thrust::cuda_cub::stream(policy); @@ -163,7 +164,7 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy &poli first, result, scan_op, - init, + InputValueT(init), num_items_fixed, stream, THRUST_DEBUG_SYNC_FLAG)); @@ -187,7 +188,7 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy &poli first, result, scan_op, - init, + InputValueT(init), num_items_fixed, stream, THRUST_DEBUG_SYNC_FLAG));