@@ -48,36 +48,42 @@ class feistel_bijection {
48
48
right_side_bits = total_bits - left_side_bits;
49
49
right_side_mask = (1ull << right_side_bits) - 1 ;
50
50
51
- for (std::uint64_t i = 0 ; i < num_rounds; i++) {
51
+ for (std::uint32_t i = 0 ; i < num_rounds; i++) {
52
52
key[i] = g ();
53
53
}
54
54
}
55
55
56
56
__host__ __device__ std::uint64_t nearest_power_of_two () const {
57
57
return 1ull << (left_side_bits + right_side_bits);
58
58
}
59
- __host__ __device__ std::uint64_t operator ()(const std::uint64_t val) const {
60
- // Extract the right and left sides of the input
61
- auto left = static_cast <std::uint32_t >(val >> right_side_bits);
62
- auto right = static_cast <std::uint32_t >(val & right_side_mask);
63
- round_state state = {left, right};
64
59
65
- for (std::uint64_t i = 0 ; i < num_rounds; i++) {
66
- state = do_round (state, i);
60
+ __host__ __device__ std::uint64_t operator ()(const std::uint64_t val) const {
61
+ std::uint32_t state[2 ] = { static_cast <std::uint32_t >( val >> right_side_bits ), static_cast <std::uint32_t >( val & right_side_mask ) };
62
+ for ( std::uint32_t i = 0 ; i < num_rounds; i++ )
63
+ {
64
+ std::uint32_t hi, lo;
65
+ constexpr std::uint64_t M0 = UINT64_C ( 0xD2B74407B1CE6E93 );
66
+ mulhilo ( M0, state[0 ], hi, lo );
67
+ lo = ( lo << ( right_side_bits - left_side_bits ) ) | state[1 ] >> left_side_bits;
68
+ state[0 ] = ( ( hi ^ key[i] ) ^ state[1 ] ) & left_side_mask;
69
+ state[1 ] = lo & right_side_mask;
67
70
}
68
-
69
- // Check we have the correct number of bits on each side
70
- assert ((state.left >> left_side_bits) == 0 );
71
- assert ((state.right >> right_side_bits) == 0 );
72
-
73
71
// Combine the left and right sides together to get result
74
- return state. left << right_side_bits | state. right ;
72
+ return static_cast <std:: uint64_t >( state[ 0 ] << right_side_bits) | static_cast <std:: uint64_t >( state[ 1 ]) ;
75
73
}
76
74
77
75
private:
76
+ // Perform 64 bit multiplication and save result in two 32 bit int
77
+ static __host__ __device__ void mulhilo ( std::uint64_t a, std::uint64_t b, std::uint32_t & hi, std::uint32_t & lo )
78
+ {
79
+ std::uint64_t product = a * b;
80
+ hi = static_cast <std::uint32_t >( product >> 32 );
81
+ lo = static_cast <std::uint32_t >( product );
82
+ }
83
+
78
84
// Find the nearest power of two
79
- __host__ __device__ std::uint64_t get_cipher_bits (std::uint64_t m) {
80
- if (m == 0 ) return 0 ;
85
+ static __host__ __device__ std::uint64_t get_cipher_bits (std::uint64_t m) {
86
+ if (m <= 16 ) return 4 ;
81
87
std::uint64_t i = 0 ;
82
88
m--;
83
89
while (m != 0 ) {
@@ -87,45 +93,12 @@ class feistel_bijection {
87
93
return i;
88
94
}
89
95
90
- // Equivalent to boost::hash_combine
91
- __host__ __device__
92
- std::size_t hash_combine (std::uint64_t lhs, std::uint64_t rhs) const {
93
- lhs ^= rhs + 0x9e3779b9 + (lhs << 6 ) + (lhs >> 2 );
94
- return lhs;
95
- }
96
-
97
- // Round function, a 'pseudorandom function' who's output is indistinguishable
98
- // from random for each key value input. This is not cryptographically secure
99
- // but sufficient for generating permutations.
100
- __host__ __device__ std::uint32_t round_function (std::uint64_t value,
101
- const std::uint64_t key_) const {
102
- std::uint64_t hash0 = thrust::random::taus88 (static_cast <std::uint32_t >(value))();
103
- std::uint64_t hash1 = thrust::random::ranlux48 (value)();
104
- return static_cast <std::uint32_t >(
105
- hash_combine (hash_combine (hash0, key_), hash1) & left_side_mask);
106
- }
107
-
108
- __host__ __device__ round_state do_round (const round_state state,
109
- const std::uint64_t round) const {
110
- const std::uint32_t new_left = state.right & left_side_mask;
111
- const std::uint32_t round_function_res =
112
- state.left ^ round_function (state.right , key[round ]);
113
- if (right_side_bits != left_side_bits) {
114
- // Upper bit of the old right becomes lower bit of new right if we have
115
- // odd length feistel
116
- const std::uint32_t new_right =
117
- (round_function_res << 1ull ) | state.right >> left_side_bits;
118
- return {new_left, new_right};
119
- }
120
- return {new_left, round_function_res};
121
- }
122
-
123
- static constexpr std::uint64_t num_rounds = 16 ;
96
+ static constexpr std::uint32_t num_rounds = 24 ;
124
97
std::uint64_t right_side_bits;
125
98
std::uint64_t left_side_bits;
126
99
std::uint64_t right_side_mask;
127
100
std::uint64_t left_side_mask;
128
- std::uint64_t key[num_rounds];
101
+ std::uint32_t key[num_rounds];
129
102
};
130
103
131
104
struct key_flag_tuple {
0 commit comments