From 0f28706acf7822f7268b00e5d7b142e036a1ba41 Mon Sep 17 00:00:00 2001 From: ronny1996 Date: Mon, 24 Jul 2023 02:29:38 +0000 Subject: [PATCH] [PHI CAPI] support get & set random seed --- paddle/phi/capi/include/c_device_context.h | 10 ++++++++ paddle/phi/capi/include/wrapper_base.h | 20 ++++++++++++++++ paddle/phi/capi/lib/c_device_context.cc | 28 ++++++++++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/paddle/phi/capi/include/c_device_context.h b/paddle/phi/capi/include/c_device_context.h index 8612eb176b56d8..464af1185642f1 100644 --- a/paddle/phi/capi/include/c_device_context.h +++ b/paddle/phi/capi/include/c_device_context.h @@ -36,6 +36,16 @@ void *PD_DeviceContextAllocateTensor(const PD_DeviceContext *ctx, PD_DataType dtype, PD_Status *status); +void PD_DeviceContextSetSeed(const PD_DeviceContext *ctx, + uint64_t seed, + PD_Status *status); + +uint64_t PD_DeviceContextGetSeed(const PD_DeviceContext *ctx, + PD_Status *status); + +uint64_t PD_DeviceContextGetRandom(const PD_DeviceContext *ctx, + PD_Status *status); + #ifdef __cplusplus } // extern "C" #endif diff --git a/paddle/phi/capi/include/wrapper_base.h b/paddle/phi/capi/include/wrapper_base.h index 1379bf108880f9..9924f4d5efb6ba 100644 --- a/paddle/phi/capi/include/wrapper_base.h +++ b/paddle/phi/capi/include/wrapper_base.h @@ -298,6 +298,26 @@ class DeviceContext : public WrapperBase { PD_CHECK_STATUS(status); return static_cast(ptr); } + + uint64_t seed() const { + C_Status status; + auto seed_val = PD_DeviceContextGetSeed(raw_data(), &status); + PD_CHECK_STATUS(status); + return seed_val; + } + + void seed(uint64_t seed_val) const { + C_Status status; + PD_DeviceContextSetSeed(raw_data(), seed_val, &status); + PD_CHECK_STATUS(status); + } + + uint64_t random() const { + C_Status status; + auto rand_val = PD_DeviceContextGetRandom(raw_data(), &status); + PD_CHECK_STATUS(status); + return rand_val; + } }; class Scalar : public WrapperBase { diff --git a/paddle/phi/capi/lib/c_device_context.cc b/paddle/phi/capi/lib/c_device_context.cc index 96b46fbc0d4ff2..b415ece7e361d2 100644 --- a/paddle/phi/capi/lib/c_device_context.cc +++ b/paddle/phi/capi/lib/c_device_context.cc @@ -74,4 +74,32 @@ void* PD_DeviceContextAllocateTensor(const PD_DeviceContext* ctx, } } +void PD_DeviceContextSetSeed(const PD_DeviceContext* ctx, + uint64_t seed, + PD_Status* status) { + if (status) { + *status = C_SUCCESS; + } + auto dev_ctx = reinterpret_cast(ctx); + dev_ctx->GetGenerator()->SetCurrentSeed(seed); +} + +uint64_t PD_DeviceContextGetSeed(const PD_DeviceContext* ctx, + PD_Status* status) { + if (status) { + *status = C_SUCCESS; + } + auto dev_ctx = reinterpret_cast(ctx); + return dev_ctx->GetGenerator()->GetCurrentSeed(); +} + +uint64_t PD_DeviceContextGetRandom(const PD_DeviceContext* ctx, + PD_Status* status) { + if (status) { + *status = C_SUCCESS; + } + auto dev_ctx = reinterpret_cast(ctx); + return dev_ctx->GetGenerator()->Random64(); +} + PD_REGISTER_CAPI(device_context);