Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Updated thrust shuffle to use improved bijective function #1566

Merged
merged 4 commits into from
Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions internal/benchmark/bench.cu
Original file line number Diff line number Diff line change
Expand Up @@ -992,15 +992,13 @@ void run_core_primitives_experiments_for_type()
, RegularTrials
>::run_experiment();

#if THRUST_CPP_DIALECT >= 2011
experiment_driver<
shuffle_tester
, ElementMetaType
, Elements / sizeof(typename ElementMetaType::type)
, BaselineTrials
, RegularTrials
>::run_experiment();
#endif
}

///////////////////////////////////////////////////////////////////////////////
Expand Down
20 changes: 9 additions & 11 deletions testing/shuffle.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <thrust/detail/config.h>

#if THRUST_CPP_DIALECT >= 2011
#include <map>
#include <limits>
#include <thrust/random.h>
Expand Down Expand Up @@ -383,7 +382,7 @@ void TestFunctionIsBijection(size_t m) {
thrust::system::detail::generic::feistel_bijection host_f(m, host_g);
thrust::system::detail::generic::feistel_bijection device_f(m, device_g);

if (host_f.nearest_power_of_two() >= std::numeric_limits<T>::max() || m == 0) {
if (static_cast<double>(host_f.nearest_power_of_two()) >= static_cast<double>(std::numeric_limits<T>::max()) || m == 0) {
return;
}

Expand All @@ -410,17 +409,17 @@ DECLARE_VARIABLE_UNITTEST(TestFunctionIsBijection);
// Checks the domain size reported by feistel_bijection::nearest_power_of_two()
// for a few requested lengths m. Each length is declared/assigned exactly once
// and paired with a single expectation.
void TestBijectionLength() {
  thrust::default_random_engine g(0xD5);

  // One below a power of two: expected to round up to 32.
  uint64_t m = 31;
  thrust::system::detail::generic::feistel_bijection f(m, g);
  ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(32));

  // Exactly a power of two: expected to stay at 32.
  m = 32;
  f = thrust::system::detail::generic::feistel_bijection(m, g);
  ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(32));

  // Small lengths: get_cipher_bits returns 4 bits for any m <= 16, so the
  // reported domain size is 16 rather than m itself.
  m = 1;
  f = thrust::system::detail::generic::feistel_bijection(m, g);
  ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(16));
}
DECLARE_UNITTEST(TestBijectionLength);

Expand Down Expand Up @@ -515,7 +514,7 @@ void TestShuffleEvenSpacingBetweenOccurances() {
thrust::host_vector<T> h_results;
Vector sequence(shuffle_size);
thrust::sequence(sequence.begin(), sequence.end(), 0);
thrust::default_random_engine g(0xD5);
thrust::default_random_engine g(0xD6);
for (auto i = 0ull; i < num_samples; i++) {
thrust::shuffle(sequence.begin(), sequence.end(), g);
thrust::host_vector<T> tmp(sequence.begin(), sequence.end());
Expand Down Expand Up @@ -561,7 +560,7 @@ void TestShuffleEvenDistribution() {
const uint64_t shuffle_sizes[] = {10, 100, 500};
thrust::default_random_engine g(0xD5);
for (auto shuffle_size : shuffle_sizes) {
if(shuffle_size > std::numeric_limits<T>::max())
if(shuffle_size > (uint64_t)std::numeric_limits<T>::max())
continue;
const uint64_t num_samples = shuffle_size == 500 ? 1000 : 200;

Expand Down Expand Up @@ -601,4 +600,3 @@ void TestShuffleEvenDistribution() {
}
}
DECLARE_INTEGRAL_VECTOR_UNITTEST(TestShuffleEvenDistribution);
#endif
75 changes: 24 additions & 51 deletions thrust/system/detail/generic/shuffle.inl
Original file line number Diff line number Diff line change
Expand Up @@ -48,36 +48,42 @@ class feistel_bijection {
right_side_bits = total_bits - left_side_bits;
right_side_mask = (1ull << right_side_bits) - 1;

for (std::uint64_t i = 0; i < num_rounds; i++) {
for (std::uint32_t i = 0; i < num_rounds; i++) {
key[i] = g();
}
}

// Returns 2 raised to the combined width of both halves, i.e.
// 2^(left_side_bits + right_side_bits).
__host__ __device__ std::uint64_t nearest_power_of_two() const {
  const std::uint64_t total_width = left_side_bits + right_side_bits;
  return std::uint64_t(1) << total_width;
}
// Maps val through num_rounds keyed mixing rounds and returns the result.
//
// val is split into a left half (the bits above right_side_bits) and a right
// half (val & right_side_mask), each held in a 32-bit word. In every round:
//   - state[0] is multiplied by the fixed constant M0 via mulhilo, giving the
//     upper (hi) and lower (lo) 32-bit words of the 64-bit product;
//   - the new left half is hi ^ key[i] ^ (old right half), masked with
//     left_side_mask;
//   - the new right half is lo shifted up by (right_side_bits - left_side_bits)
//     with the old right half's bits above left_side_bits shifted in, masked
//     with right_side_mask.
__host__ __device__ std::uint64_t operator()(const std::uint64_t val) const {
  std::uint32_t state[2] = {static_cast<std::uint32_t>(val >> right_side_bits),
                            static_cast<std::uint32_t>(val & right_side_mask)};
  for (std::uint32_t i = 0; i < num_rounds; i++) {
    std::uint32_t hi, lo;
    constexpr std::uint64_t M0 = UINT64_C(0xD2B74407B1CE6E93);
    mulhilo(M0, state[0], hi, lo);
    // lo must be fully updated from the OLD state[1] before state[1] is
    // overwritten below; state[0] likewise reads the old state[1].
    lo = (lo << (right_side_bits - left_side_bits)) | state[1] >> left_side_bits;
    state[0] = ((hi ^ key[i]) ^ state[1]) & left_side_mask;
    state[1] = lo & right_side_mask;
  }

  // Combine the left and right sides together to get the result. The left half
  // is widened to 64 bits BEFORE shifting: shifting the 32-bit word first would
  // be evaluated in 32-bit arithmetic and drop the high bits whenever
  // left_side_bits + right_side_bits exceeds 32.
  return (static_cast<std::uint64_t>(state[0]) << right_side_bits) |
         static_cast<std::uint64_t>(state[1]);
}

private:
// Multiplies a by b modulo 2^64 and splits the 64-bit result into two 32-bit
// words: hi receives bits 63..32 and lo receives bits 31..0. Any bits of the
// mathematical product above bit 63 are discarded by the unsigned wrap-around.
//
// Deliberately not constexpr: a void function containing a local variable and
// assignments is only a valid constexpr function from C++14 on, and it is never
// needed in a constant expression here.
static __host__ __device__ void mulhilo(std::uint64_t a, std::uint64_t b, std::uint32_t& hi, std::uint32_t& lo)
{
  const std::uint64_t product = a * b;
  hi = static_cast<std::uint32_t>(product >> 32);
  lo = static_cast<std::uint32_t>(product);
}

// Find the nearest power of two
__host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) {
if (m == 0) return 0;
constexpr static __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) {
if (m <= 16) return 4;
std::uint64_t i = 0;
m--;
while (m != 0) {
Expand All @@ -87,45 +93,12 @@ class feistel_bijection {
return i;
}

// Folds rhs into lhs using the boost::hash_combine mixing formula:
// lhs ^ (rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2)).
__host__ __device__
std::size_t hash_combine(std::uint64_t lhs, std::uint64_t rhs) const {
  const std::uint64_t shifted_up = lhs << 6;
  const std::uint64_t shifted_down = lhs >> 2;
  const std::uint64_t mix = rhs + 0x9e3779b9 + shifted_up + shifted_down;
  return lhs ^ mix;
}

// Round function: a 'pseudorandom function' whose output is meant to be
// indistinguishable from random for each key value input. It is not
// cryptographically secure, but sufficient for generating permutations.
//
// Draws one value from a taus88 engine seeded with the low 32 bits of `value`
// and one from a ranlux48 engine seeded with `value`, mixes them with `key_`
// through hash_combine, and keeps only the bits selected by left_side_mask.
__host__ __device__ std::uint32_t round_function(std::uint64_t value,
                                                 const std::uint64_t key_) const {
  const std::uint32_t narrow_seed = static_cast<std::uint32_t>(value);
  std::uint64_t taus_draw = thrust::random::taus88(narrow_seed)();
  std::uint64_t ranlux_draw = thrust::random::ranlux48(value)();

  const std::size_t keyed = hash_combine(taus_draw, key_);
  const std::size_t mixed = hash_combine(keyed, ranlux_draw);
  return static_cast<std::uint32_t>(mixed & left_side_mask);
}

// Performs one round on `state` using key[round] and returns the new state.
//
// The new left half is the old right half masked with left_side_mask. The old
// left half is XORed with round_function(old right, key[round]); when both
// halves have the same width that value becomes the new right half directly.
// When the widths differ (odd-length Feistel), it is shifted up by one bit and
// the old right half's bits above left_side_bits fill the low end.
__host__ __device__ round_state do_round(const round_state state,
                                         const std::uint64_t round) const {
  const std::uint32_t next_left = state.right & left_side_mask;
  const std::uint32_t mixed =
      state.left ^ round_function(state.right, key[round]);

  const bool uneven_split = (right_side_bits != left_side_bits);
  const std::uint32_t next_right =
      uneven_split
          ? static_cast<std::uint32_t>((mixed << 1ull) |
                                       state.right >> left_side_bits)
          : mixed;
  return {next_left, next_right};
}

static constexpr std::uint64_t num_rounds = 16;
static constexpr std::uint32_t num_rounds = 24;
std::uint64_t right_side_bits;
std::uint64_t left_side_bits;
std::uint64_t right_side_mask;
std::uint64_t left_side_mask;
std::uint64_t key[num_rounds];
std::uint32_t key[num_rounds];
};

struct key_flag_tuple {
Expand Down