Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit efc5fc8

Browse files
authored
Merge pull request #1566 from djns99/PhiloxShuffle
Updated thrust shuffle to use improved bijective function
2 parents f80c77b + 616f17d commit efc5fc8

File tree

3 files changed

+33
-64
lines changed

3 files changed

+33
-64
lines changed

internal/benchmark/bench.cu

-2
Original file line number | Diff line number | Diff line change
@@ -992,15 +992,13 @@ void run_core_primitives_experiments_for_type()
992992
, RegularTrials
993993
>::run_experiment();
994994

995-
#if THRUST_CPP_DIALECT >= 2011
996995
experiment_driver<
997996
shuffle_tester
998997
, ElementMetaType
999998
, Elements / sizeof(typename ElementMetaType::type)
1000999
, BaselineTrials
10011000
, RegularTrials
10021001
>::run_experiment();
1003-
#endif
10041002
}
10051003

10061004
///////////////////////////////////////////////////////////////////////////////

testing/shuffle.cu

+9-11
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
#include <thrust/detail/config.h>
22

3-
#if THRUST_CPP_DIALECT >= 2011
43
#include <map>
54
#include <limits>
65
#include <thrust/random.h>
@@ -383,7 +382,7 @@ void TestFunctionIsBijection(size_t m) {
383382
thrust::system::detail::generic::feistel_bijection host_f(m, host_g);
384383
thrust::system::detail::generic::feistel_bijection device_f(m, device_g);
385384

386-
if (host_f.nearest_power_of_two() >= std::numeric_limits<T>::max() || m == 0) {
385+
if (static_cast<double>(host_f.nearest_power_of_two()) >= static_cast<double>(std::numeric_limits<T>::max()) || m == 0) {
387386
return;
388387
}
389388

@@ -410,17 +409,17 @@ DECLARE_VARIABLE_UNITTEST(TestFunctionIsBijection);
410409
void TestBijectionLength() {
411410
thrust::default_random_engine g(0xD5);
412411

413-
uint64_t m = 3;
412+
uint64_t m = 31;
414413
thrust::system::detail::generic::feistel_bijection f(m, g);
415-
ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(4));
414+
ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(32));
416415

417-
m = 2;
416+
m = 32;
418417
f = thrust::system::detail::generic::feistel_bijection(m, g);
419-
ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(2));
418+
ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(32));
420419

421-
m = 0;
420+
m = 1;
422421
f = thrust::system::detail::generic::feistel_bijection(m, g);
423-
ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(1));
422+
ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(16));
424423
}
425424
DECLARE_UNITTEST(TestBijectionLength);
426425

@@ -515,7 +514,7 @@ void TestShuffleEvenSpacingBetweenOccurances() {
515514
thrust::host_vector<T> h_results;
516515
Vector sequence(shuffle_size);
517516
thrust::sequence(sequence.begin(), sequence.end(), 0);
518-
thrust::default_random_engine g(0xD5);
517+
thrust::default_random_engine g(0xD6);
519518
for (auto i = 0ull; i < num_samples; i++) {
520519
thrust::shuffle(sequence.begin(), sequence.end(), g);
521520
thrust::host_vector<T> tmp(sequence.begin(), sequence.end());
@@ -561,7 +560,7 @@ void TestShuffleEvenDistribution() {
561560
const uint64_t shuffle_sizes[] = {10, 100, 500};
562561
thrust::default_random_engine g(0xD5);
563562
for (auto shuffle_size : shuffle_sizes) {
564-
if(shuffle_size > std::numeric_limits<T>::max())
563+
if(shuffle_size > (uint64_t)std::numeric_limits<T>::max())
565564
continue;
566565
const uint64_t num_samples = shuffle_size == 500 ? 1000 : 200;
567566

@@ -601,4 +600,3 @@ void TestShuffleEvenDistribution() {
601600
}
602601
}
603602
DECLARE_INTEGRAL_VECTOR_UNITTEST(TestShuffleEvenDistribution);
604-
#endif

thrust/system/detail/generic/shuffle.inl

+24-51
Original file line number | Diff line number | Diff line change
@@ -48,36 +48,42 @@ class feistel_bijection {
4848
right_side_bits = total_bits - left_side_bits;
4949
right_side_mask = (1ull << right_side_bits) - 1;
5050

51-
for (std::uint64_t i = 0; i < num_rounds; i++) {
51+
for (std::uint32_t i = 0; i < num_rounds; i++) {
5252
key[i] = g();
5353
}
5454
}
5555

5656
__host__ __device__ std::uint64_t nearest_power_of_two() const {
5757
return 1ull << (left_side_bits + right_side_bits);
5858
}
59-
__host__ __device__ std::uint64_t operator()(const std::uint64_t val) const {
60-
// Extract the right and left sides of the input
61-
auto left = static_cast<std::uint32_t>(val >> right_side_bits);
62-
auto right = static_cast<std::uint32_t>(val & right_side_mask);
63-
round_state state = {left, right};
6459

65-
for (std::uint64_t i = 0; i < num_rounds; i++) {
66-
state = do_round(state, i);
60+
__host__ __device__ std::uint64_t operator()(const std::uint64_t val) const {
61+
std::uint32_t state[2] = { static_cast<std::uint32_t>( val >> right_side_bits ), static_cast<std::uint32_t>( val & right_side_mask ) };
62+
for( std::uint32_t i = 0; i < num_rounds; i++ )
63+
{
64+
std::uint32_t hi, lo;
65+
constexpr std::uint64_t M0 = UINT64_C( 0xD2B74407B1CE6E93 );
66+
mulhilo( M0, state[0], hi, lo );
67+
lo = ( lo << ( right_side_bits - left_side_bits ) ) | state[1] >> left_side_bits;
68+
state[0] = ( ( hi ^ key[i] ) ^ state[1] ) & left_side_mask;
69+
state[1] = lo & right_side_mask;
6770
}
68-
69-
// Check we have the correct number of bits on each side
70-
assert((state.left >> left_side_bits) == 0);
71-
assert((state.right >> right_side_bits) == 0);
72-
7371
// Combine the left and right sides together to get result
74-
return state.left << right_side_bits | state.right;
72+
return static_cast<std::uint64_t>(state[0] << right_side_bits) | static_cast<std::uint64_t>(state[1]);
7573
}
7674

7775
private:
76+
// Perform 64 bit multiplication and save result in two 32 bit int
77+
static __host__ __device__ void mulhilo( std::uint64_t a, std::uint64_t b, std::uint32_t& hi, std::uint32_t& lo )
78+
{
79+
std::uint64_t product = a * b;
80+
hi = static_cast<std::uint32_t>( product >> 32 );
81+
lo = static_cast<std::uint32_t>( product );
82+
}
83+
7884
// Find the nearest power of two
79-
__host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) {
80-
if (m == 0) return 0;
85+
static __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) {
86+
if (m <= 16) return 4;
8187
std::uint64_t i = 0;
8288
m--;
8389
while (m != 0) {
@@ -87,45 +93,12 @@ class feistel_bijection {
8793
return i;
8894
}
8995

90-
// Equivalent to boost::hash_combine
91-
__host__ __device__
92-
std::size_t hash_combine(std::uint64_t lhs, std::uint64_t rhs) const {
93-
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
94-
return lhs;
95-
}
96-
97-
// Round function, a 'pseudorandom function' who's output is indistinguishable
98-
// from random for each key value input. This is not cryptographically secure
99-
// but sufficient for generating permutations.
100-
__host__ __device__ std::uint32_t round_function(std::uint64_t value,
101-
const std::uint64_t key_) const {
102-
std::uint64_t hash0 = thrust::random::taus88(static_cast<std::uint32_t>(value))();
103-
std::uint64_t hash1 = thrust::random::ranlux48(value)();
104-
return static_cast<std::uint32_t>(
105-
hash_combine(hash_combine(hash0, key_), hash1) & left_side_mask);
106-
}
107-
108-
__host__ __device__ round_state do_round(const round_state state,
109-
const std::uint64_t round) const {
110-
const std::uint32_t new_left = state.right & left_side_mask;
111-
const std::uint32_t round_function_res =
112-
state.left ^ round_function(state.right, key[round]);
113-
if (right_side_bits != left_side_bits) {
114-
// Upper bit of the old right becomes lower bit of new right if we have
115-
// odd length feistel
116-
const std::uint32_t new_right =
117-
(round_function_res << 1ull) | state.right >> left_side_bits;
118-
return {new_left, new_right};
119-
}
120-
return {new_left, round_function_res};
121-
}
122-
123-
static constexpr std::uint64_t num_rounds = 16;
96+
static constexpr std::uint32_t num_rounds = 24;
12497
std::uint64_t right_side_bits;
12598
std::uint64_t left_side_bits;
12699
std::uint64_t right_side_mask;
127100
std::uint64_t left_side_mask;
128-
std::uint64_t key[num_rounds];
101+
std::uint32_t key[num_rounds];
129102
};
130103

131104
struct key_flag_tuple {

0 commit comments

Comments
 (0)