Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
cub change cub/pull/305
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Oct 12, 2021
1 parent b4fe20e commit e1d0681
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
11 changes: 7 additions & 4 deletions thrust/system/cuda/detail/async/exclusive_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,20 @@ async_exclusive_scan_n(execution_policy<DerivedPolicy>& policy,
InitialValueType init,
BinaryOp op)
{
using InputValueT = cub::detail::InputValue<InitialValueType>;
using Dispatch32 = cub::DispatchScan<ForwardIt,
OutputIt,
BinaryOp,
InitialValueType,
InputValueT,
thrust::detail::int32_t>;
using Dispatch64 = cub::DispatchScan<ForwardIt,
OutputIt,
BinaryOp,
InitialValueType,
InputValueT,
thrust::detail::int64_t>;

InputValueT init_value(init);

auto const device_alloc = get_async_device_allocator(policy);
unique_eager_event ev;

Expand All @@ -101,7 +104,7 @@ async_exclusive_scan_n(execution_policy<DerivedPolicy>& policy,
first,
out,
op,
init,
init_value,
n_fixed,
nullptr,
THRUST_DEBUG_SYNC_FLAG));
Expand Down Expand Up @@ -148,7 +151,7 @@ async_exclusive_scan_n(execution_policy<DerivedPolicy>& policy,
first,
out,
op,
init,
init_value,
n_fixed,
user_raw_stream,
THRUST_DEBUG_SYNC_FLAG));
Expand Down
9 changes: 5 additions & 4 deletions thrust/system/cuda/detail/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,16 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy<Derived> &poli
InitValueT init,
ScanOp scan_op)
{
using InputValueT = cub::detail::InputValue<InitValueT>;
using Dispatch32 = cub::DispatchScan<InputIt,
OutputIt,
ScanOp,
InitValueT,
InputValueT,
thrust::detail::int32_t>;
using Dispatch64 = cub::DispatchScan<InputIt,
OutputIt,
ScanOp,
InitValueT,
InputValueT,
thrust::detail::int64_t>;

cudaStream_t stream = thrust::cuda_cub::stream(policy);
Expand All @@ -163,7 +164,7 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy<Derived> &poli
first,
result,
scan_op,
init,
InputValueT(init),
num_items_fixed,
stream,
THRUST_DEBUG_SYNC_FLAG));
Expand All @@ -187,7 +188,7 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy<Derived> &poli
first,
result,
scan_op,
init,
InputValueT(init),
num_items_fixed,
stream,
THRUST_DEBUG_SYNC_FLAG));
Expand Down

0 comments on commit e1d0681

Please sign in to comment.