Skip to content

Commit

Permalink
fix seed for class_center_sample using paddle.seed (#38248)
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoxiaWang authored Dec 18, 2021
1 parent 6418bc7 commit 59be8e0
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions paddle/fluid/operators/class_center_sample_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -390,18 +390,26 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
ctx.cuda_device_context().stream())));

// step 5: random sample negative class center
uint64_t seed_data;
uint64_t increment;
int vec_size = VectorizedSize<T>(cub_sort_keys_ptr);
int increment = ((num_classes - 1) /
(NumBlocks(num_classes) * kNumCUDAThreads * vec_size) +
1) *
vec_size;
if (!fix_seed) {
auto offset = ((num_classes - 1) /
(NumBlocks(num_classes) * kNumCUDAThreads * vec_size) +
1) *
vec_size;
auto gen_cuda = framework::GetDefaultCUDAGenerator(rank);
if (gen_cuda->GetIsInitPy() && (!fix_seed)) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
seed_data = seed_offset.first;
increment = seed_offset.second;
} else {
std::random_device rnd;
seed = rnd();
seed_data = fix_seed ? seed + rank : rnd();
increment = offset;
}
RandomSampleClassCenter<T><<<NumBlocks(num_classes), kNumCUDAThreads, 0,
ctx.cuda_device_context().stream()>>>(
num_classes, seed + rank, increment, num_classes, cub_sort_keys_ptr);
num_classes, seed_data, increment, num_classes, cub_sort_keys_ptr);

// step 6: mark positive class center as negative value
// fill the sort values to index 0, 1, ..., batch_size-1
Expand Down

0 comments on commit 59be8e0

Please sign in to comment.