Skip to content

Commit

Permalink
Add normalization ops
Browse files Browse the repository at this point in the history
  • Loading branch information
StarsX committed Oct 13, 2021
1 parent c67f5b4 commit ee1fe00
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 7 deletions.
52 changes: 50 additions & 2 deletions XUSGMachineLearning/MachineLearning/XUSGMachineLearning.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ namespace XUSG
// Flags attached to a tensor. The DEFINE_ENUM_FLAG_OPERATORS(TensorFlag) line that follows
// presumably enables bitwise operators on this enum — confirm against the macro definition.
// NOTE(review): MANAGED is listed twice (with and without a trailing comma). This looks like the
// removed/added sides of a diff hunk rather than two enumerators — confirm against the real header.
enum class TensorFlag
{
NONE,
MANAGED,
MANAGED
};

DEFINE_ENUM_FLAG_OPERATORS(TensorFlag);
Expand All @@ -119,7 +119,7 @@ namespace XUSG
NONE,
ALLOW_HALF_PRECISION_COMPUTATION,
DISABLE_META_COMMANDS,
DESCRIPTORS_VOLATILE,
DESCRIPTORS_VOLATILE
};

DEFINE_ENUM_FLAG_OPERATORS(ExecutionFlag);
Expand Down Expand Up @@ -521,6 +521,54 @@ namespace XUSG
uint32_t K;
};

// Backend-independent description of a batch-normalization operator.
// The DirectML backend (getDMLBatchNormalization) copies each field into the same-named member of
// DML_BATCH_NORMALIZATION_OPERATOR_DESC; see the DirectML documentation for the numeric semantics.
// Any tensor pointer left null is forwarded to DirectML as a null tensor description.
struct BatchNormalization

{
const Tensor* pInput; // Forwarded as InputTensor.
const Tensor* pMean; // Forwarded as MeanTensor.
const Tensor* pVariance; // Forwarded as VarianceTensor.
const Tensor* pScale; // Forwarded as ScaleTensor.
const Tensor* pBias; // Forwarded as BiasTensor.
const Tensor* pOutput; // Forwarded as OutputTensor.
bool Spatial; // Forwarded as Spatial.
float Epsilon; // Forwarded as Epsilon.
OperatorType FusedActivationType; // Operator type of the fused activation; only read when pFusedActivation is non-null.
const void* pFusedActivation; // Description of the fused activation, interpreted according to FusedActivationType; null means no fused activation.
};

// Backend-independent description of a mean-variance-normalization operator.
// The DirectML backend (getDMLMeanVarianceNormalization) copies each field into the same-named
// member of DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC; see the DirectML documentation for the
// numeric semantics. Any tensor pointer left null is forwarded as a null tensor description.
struct MeanVarianceNormalization
{
const Tensor* pInput; // Forwarded as InputTensor.
const Tensor* pScale; // Forwarded as ScaleTensor.
const Tensor* pBias; // Forwarded as BiasTensor.
const Tensor* pOutput; // Forwarded as OutputTensor.
bool CrossChannel; // Forwarded as CrossChannel.
bool NormalizeVariance; // Forwarded as NormalizeVariance.
float Epsilon; // Forwarded as Epsilon.
OperatorType FusedActivationType; // Operator type of the fused activation; only read when pFusedActivation is non-null.
const void* pFusedActivation; // Description of the fused activation, interpreted according to FusedActivationType; null means no fused activation.
};

// Backend-independent description of a local-response-normalization operator.
// The DirectML backend (getDMLLocalResponseNormalization) copies each field into the same-named
// member of DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC; see the DirectML documentation for the
// numeric semantics. Any tensor pointer left null is forwarded as a null tensor description.
struct LocalResponseNormalization
{
const Tensor* pInput; // Forwarded as InputTensor.
const Tensor* pOutput; // Forwarded as OutputTensor.
bool CrossChannel; // Forwarded as CrossChannel.
uint32_t LocalSize; // Forwarded as LocalSize.
float Alpha; // Forwarded as Alpha.
float Beta; // Forwarded as Beta.
float Bias; // Forwarded as Bias.
};

// Backend-independent description of an Lp-normalization operator.
// The DirectML backend (getDMLLPNormalization) copies each field into the same-named member of
// DML_LP_NORMALIZATION_OPERATOR_DESC; see the DirectML documentation for the numeric semantics.
// Any tensor pointer left null is forwarded as a null tensor description.
struct LPNormalization
{
const Tensor* pInput; // Forwarded as InputTensor.
const Tensor* pOutput; // Forwarded as OutputTensor.
uint32_t Axis; // Forwarded as Axis.
float Epsilon; // Forwarded as Epsilon.
uint32_t P; // Forwarded as P.
};

//--------------------------------------------------------------------------------------
// Device
//--------------------------------------------------------------------------------------
Expand Down
99 changes: 99 additions & 0 deletions XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,100 @@ void ML::GetDMLTypedOperator(vector<uint8_t>& dmlTypedOpDesc, OperatorType type,
pDMLDesc->K = desc.K;
};

// Packs a BatchNormalization description into the bytes of dmlTypedOpDesc.
//
// dmlTypedOpDesc - output buffer; it is resized and laid out as
//     [DML_BATCH_NORMALIZATION_OPERATOR_DESC][DML_OPERATOR_DESC][typed fused-activation desc]
//   where the last two parts exist only when the fused activation is present.
// pOpDesc        - must point to a BatchNormalization.
//
// The FusedActivation and Desc pointers written into the buffer point back into the buffer
// itself, so they are only valid while dmlTypedOpDesc is not reallocated or copied.
static const auto getDMLBatchNormalization = [](vector<uint8_t>& dmlTypedOpDesc, const void* pOpDesc)
{
	const auto& desc = *static_cast<const BatchNormalization*>(pOpDesc);

	// Build the typed desc of the fused activation first: its size is needed to size the output.
	vector<uint8_t> typedFused(0);
	if (desc.pFusedActivation) GetDMLTypedOperator(typedFused, desc.FusedActivationType, desc.pFusedActivation);

	// Byte offsets of the optional trailing parts. They must be derived from the desc type that
	// actually heads this buffer; using another operator's size places the fused desc at the
	// wrong position and can write outside the buffer.
	const auto fusedOpOffset = sizeof(DML_BATCH_NORMALIZATION_OPERATOR_DESC);
	const auto fusedTypedOffset = fusedOpOffset + sizeof(DML_OPERATOR_DESC);

	dmlTypedOpDesc.resize(desc.pFusedActivation ? fusedTypedOffset + typedFused.size() : fusedOpOffset);
	const auto pDMLDesc = reinterpret_cast<DML_BATCH_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data());
	const auto pDMLFused = desc.pFusedActivation ?
		reinterpret_cast<DML_OPERATOR_DESC*>(dmlTypedOpDesc.data() + fusedOpOffset) : nullptr;

	// A null tensor is forwarded as a null DML tensor description.
	pDMLDesc->InputTensor = desc.pInput ? static_cast<const DML_TENSOR_DESC*>(desc.pInput->GetHandle()) : nullptr;
	pDMLDesc->MeanTensor = desc.pMean ? static_cast<const DML_TENSOR_DESC*>(desc.pMean->GetHandle()) : nullptr;
	pDMLDesc->VarianceTensor = desc.pVariance ? static_cast<const DML_TENSOR_DESC*>(desc.pVariance->GetHandle()) : nullptr;
	pDMLDesc->ScaleTensor = desc.pScale ? static_cast<const DML_TENSOR_DESC*>(desc.pScale->GetHandle()) : nullptr;
	pDMLDesc->BiasTensor = desc.pBias ? static_cast<const DML_TENSOR_DESC*>(desc.pBias->GetHandle()) : nullptr;
	pDMLDesc->OutputTensor = desc.pOutput ? static_cast<const DML_TENSOR_DESC*>(desc.pOutput->GetHandle()) : nullptr;
	pDMLDesc->Spatial = desc.Spatial;
	pDMLDesc->Epsilon = desc.Epsilon;
	pDMLDesc->FusedActivation = pDMLFused;

	if (pDMLFused)
	{
		assert(desc.pFusedActivation);
		assert(fusedTypedOffset + typedFused.size() == dmlTypedOpDesc.size());

		// data() + offset (rather than operator[]) stays valid even when the typed desc is
		// empty and the offset equals the buffer size.
		const auto pFusedTyped = dmlTypedOpDesc.data() + fusedTypedOffset;
		pDMLFused->Type = GetDMLOpteratorType(desc.FusedActivationType);
		pDMLFused->Desc = pFusedTyped;
		memcpy(pFusedTyped, typedFused.data(), typedFused.size());
	}
};

// Packs a MeanVarianceNormalization description into the bytes of dmlTypedOpDesc.
//
// dmlTypedOpDesc - output buffer; it is resized and laid out as
//     [DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC][DML_OPERATOR_DESC][typed fused-activation desc]
//   where the last two parts exist only when the fused activation is present.
// pOpDesc        - must point to a MeanVarianceNormalization.
//
// The FusedActivation and Desc pointers written into the buffer point back into the buffer
// itself, so they are only valid while dmlTypedOpDesc is not reallocated or copied.
static const auto getDMLMeanVarianceNormalization = [](vector<uint8_t>& dmlTypedOpDesc, const void* pOpDesc)
{
	const auto& desc = *static_cast<const MeanVarianceNormalization*>(pOpDesc);

	// Build the typed desc of the fused activation first: its size is needed to size the output.
	vector<uint8_t> typedFused(0);
	if (desc.pFusedActivation) GetDMLTypedOperator(typedFused, desc.FusedActivationType, desc.pFusedActivation);

	// Byte offsets of the optional trailing parts. They must be derived from the desc type that
	// actually heads this buffer; using another operator's size places the fused desc at the
	// wrong position and can write outside the buffer.
	const auto fusedOpOffset = sizeof(DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC);
	const auto fusedTypedOffset = fusedOpOffset + sizeof(DML_OPERATOR_DESC);

	dmlTypedOpDesc.resize(desc.pFusedActivation ? fusedTypedOffset + typedFused.size() : fusedOpOffset);
	const auto pDMLDesc = reinterpret_cast<DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data());
	const auto pDMLFused = desc.pFusedActivation ?
		reinterpret_cast<DML_OPERATOR_DESC*>(dmlTypedOpDesc.data() + fusedOpOffset) : nullptr;

	// A null tensor is forwarded as a null DML tensor description.
	pDMLDesc->InputTensor = desc.pInput ? static_cast<const DML_TENSOR_DESC*>(desc.pInput->GetHandle()) : nullptr;
	pDMLDesc->ScaleTensor = desc.pScale ? static_cast<const DML_TENSOR_DESC*>(desc.pScale->GetHandle()) : nullptr;
	pDMLDesc->BiasTensor = desc.pBias ? static_cast<const DML_TENSOR_DESC*>(desc.pBias->GetHandle()) : nullptr;
	pDMLDesc->OutputTensor = desc.pOutput ? static_cast<const DML_TENSOR_DESC*>(desc.pOutput->GetHandle()) : nullptr;
	pDMLDesc->CrossChannel = desc.CrossChannel;
	pDMLDesc->NormalizeVariance = desc.NormalizeVariance;
	pDMLDesc->Epsilon = desc.Epsilon;
	pDMLDesc->FusedActivation = pDMLFused;

	if (pDMLFused)
	{
		assert(desc.pFusedActivation);
		assert(fusedTypedOffset + typedFused.size() == dmlTypedOpDesc.size());

		// data() + offset (rather than operator[]) stays valid even when the typed desc is
		// empty and the offset equals the buffer size.
		const auto pFusedTyped = dmlTypedOpDesc.data() + fusedTypedOffset;
		pDMLFused->Type = GetDMLOpteratorType(desc.FusedActivationType);
		pDMLFused->Desc = pFusedTyped;
		memcpy(pFusedTyped, typedFused.data(), typedFused.size());
	}
};

// Packs a LocalResponseNormalization description into dmlTypedOpDesc, which is resized to hold
// exactly one DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC.
// pOpDesc must point to a LocalResponseNormalization.
static const auto getDMLLocalResponseNormalization = [](vector<uint8_t>& dmlTypedOpDesc, const void* pOpDesc)
{
	const auto& src = *static_cast<const LocalResponseNormalization*>(pOpDesc);

	dmlTypedOpDesc.resize(sizeof(DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC));
	auto& dst = *reinterpret_cast<DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data());

	// Tensors: a null tensor is forwarded as a null DML tensor description.
	dst.InputTensor = nullptr;
	if (src.pInput) dst.InputTensor = static_cast<const DML_TENSOR_DESC*>(src.pInput->GetHandle());

	dst.OutputTensor = nullptr;
	if (src.pOutput) dst.OutputTensor = static_cast<const DML_TENSOR_DESC*>(src.pOutput->GetHandle());

	// Scalar parameters are copied through unchanged.
	dst.CrossChannel = src.CrossChannel;
	dst.LocalSize = src.LocalSize;
	dst.Alpha = src.Alpha;
	dst.Beta = src.Beta;
	dst.Bias = src.Bias;
};

// Packs an LPNormalization description into dmlTypedOpDesc, which is resized to hold exactly one
// DML_LP_NORMALIZATION_OPERATOR_DESC.
// pOpDesc must point to an LPNormalization.
static const auto getDMLLPNormalization = [](vector<uint8_t>& dmlTypedOpDesc, const void* pOpDesc)
{
	const auto& src = *static_cast<const LPNormalization*>(pOpDesc);

	dmlTypedOpDesc.resize(sizeof(DML_LP_NORMALIZATION_OPERATOR_DESC));
	auto& dst = *reinterpret_cast<DML_LP_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data());

	// Tensors: a null tensor is forwarded as a null DML tensor description.
	dst.InputTensor = nullptr;
	if (src.pInput) dst.InputTensor = static_cast<const DML_TENSOR_DESC*>(src.pInput->GetHandle());

	dst.OutputTensor = nullptr;
	if (src.pOutput) dst.OutputTensor = static_cast<const DML_TENSOR_DESC*>(src.pOutput->GetHandle());

	// Scalar parameters are copied through unchanged.
	dst.Axis = src.Axis;
	dst.Epsilon = src.Epsilon;
	dst.P = src.P;
};


static const function<void(vector<uint8_t>&, const void*)> pfnGetDMLOps[] =
{
nullptr, // INVALID
Expand Down Expand Up @@ -768,6 +862,11 @@ void ML::GetDMLTypedOperator(vector<uint8_t>& dmlTypedOpDesc, OperatorType type,
getDMLSpaceDepth, // DEPTH_TO_SPACE
getDMLTile, // TILE
getDMLTopK, // TOP_K

getDMLBatchNormalization, // BATCH_NORMALIZATION
getDMLMeanVarianceNormalization, // MEAN_VARIANCE_NORMALIZATION
getDMLLocalResponseNormalization, // LOCAL_RESPONSE_NORMALIZATION
getDMLLPNormalization, // LP_NORMALIZATION
};

pfnGetDMLOps[static_cast<uint32_t>(type)](dmlTypedOpDesc, pOpDesc);
Expand Down
5 changes: 0 additions & 5 deletions XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,6 @@ namespace XUSG
com_ptr<IDMLDevice> m_device;
};

using BatchNormalization = DML_BATCH_NORMALIZATION_OPERATOR_DESC;
using MeanVarianceNormalization = DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC;
using LocalResponseNormalization = DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC;
using LPNormalization = DML_LP_NORMALIZATION_OPERATOR_DESC;

using RNNOperator = DML_RNN_OPERATOR_DESC;
using LSTMOperator = DML_LSTM_OPERATOR_DESC;
using GRUOperator = DML_GRU_OPERATOR_DESC;
Expand Down

0 comments on commit ee1fe00

Please sign in to comment.