Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Update for cub::FutureValue PR (NVIDIA/cub#305) #1519

Merged
merged 1 commit into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions thrust/system/cuda/detail/async/exclusive_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,20 @@ async_exclusive_scan_n(execution_policy<DerivedPolicy>& policy,
InitialValueType init,
BinaryOp op)
{
using InputValueT = cub::detail::InputValue<InitialValueType>;
using Dispatch32 = cub::DispatchScan<ForwardIt,
OutputIt,
BinaryOp,
InitialValueType,
InputValueT,
thrust::detail::int32_t>;
using Dispatch64 = cub::DispatchScan<ForwardIt,
OutputIt,
BinaryOp,
InitialValueType,
InputValueT,
thrust::detail::int64_t>;

InputValueT init_value(init);

auto const device_alloc = get_async_device_allocator(policy);
unique_eager_event ev;

Expand All @@ -101,7 +104,7 @@ async_exclusive_scan_n(execution_policy<DerivedPolicy>& policy,
first,
out,
op,
init,
init_value,
n_fixed,
nullptr,
THRUST_DEBUG_SYNC_FLAG));
Expand Down Expand Up @@ -148,7 +151,7 @@ async_exclusive_scan_n(execution_policy<DerivedPolicy>& policy,
first,
out,
op,
init,
init_value,
n_fixed,
user_raw_stream,
THRUST_DEBUG_SYNC_FLAG));
Expand Down
9 changes: 5 additions & 4 deletions thrust/system/cuda/detail/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,16 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy<Derived> &poli
InitValueT init,
ScanOp scan_op)
{
using InputValueT = cub::detail::InputValue<InitValueT>;
using Dispatch32 = cub::DispatchScan<InputIt,
OutputIt,
ScanOp,
InitValueT,
InputValueT,
thrust::detail::int32_t>;
using Dispatch64 = cub::DispatchScan<InputIt,
OutputIt,
ScanOp,
InitValueT,
InputValueT,
thrust::detail::int64_t>;

cudaStream_t stream = thrust::cuda_cub::stream(policy);
Expand All @@ -163,7 +164,7 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy<Derived> &poli
first,
result,
scan_op,
init,
InputValueT(init),
num_items_fixed,
stream,
THRUST_DEBUG_SYNC_FLAG));
Expand All @@ -187,7 +188,7 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy<Derived> &poli
first,
result,
scan_op,
init,
InputValueT(init),
num_items_fixed,
stream,
THRUST_DEBUG_SYNC_FLAG));
Expand Down