Skip to content

Commit

Permalink
Fix GRU tests (microsoft#22716)
Browse files Browse the repository at this point in the history
### Description
Many GRU tests were being skipped due to an error in
MLOperatorAuthorImpl.cpp. The issue was caused by activation function
names not being capitalized (e.g., 'sigmoid'), while the AttrValue
constants used mixed case (e.g., 'Sigmoid', 'LeakyRelu'), which resulted
in an 'unsupported activation function' error in
DMLOperatorRecurrentNeuralNetwork.cpp.
This PR fixes the issue by making the DML EP activation function name
comparison case-insensitive, and by capitalizing the activation function
names in the tests.

ref PR: microsoft#15914
ref bug: https://dev.azure.com/microsoft/OS/_workitems/edit/44571772

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: nums11 <numsmt2@gmail.com>
  • Loading branch information
2 people authored and ankitm3k committed Dec 11, 2024
1 parent 53a4319 commit 685f5a5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,51 +127,51 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper
DML_OPERATOR_DESC& desc = descs[i];
ActivationOperatorDescUnion& activationDesc = m_activationDescs[i];
desc.Desc = &activationDesc;

if (activationName == AttrValue::ActivationRelu)
if (CompareActivationName(activationName, AttrValue::ActivationRelu))
{
desc.Type = DML_OPERATOR_ACTIVATION_RELU;
}
else if (activationName == AttrValue::ActivationLeakyRelu)
}
else if (CompareActivationName(activationName, AttrValue::ActivationLeakyRelu))
{
desc.Type = DML_OPERATOR_ACTIVATION_LEAKY_RELU;
activationDesc.leakyRelu.Alpha = NextAlpha(desc.Type);
}
else if (activationName == AttrValue::ActivationThresholdedRelu)
else if (CompareActivationName(activationName, AttrValue::ActivationThresholdedRelu))
{
desc.Type = DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU;
activationDesc.thresholdedRelu.Alpha = NextAlpha(desc.Type);
}
else if (activationName == AttrValue::ActivationTanh)
}
else if (CompareActivationName(activationName, AttrValue::ActivationTanh))
{
desc.Type = DML_OPERATOR_ACTIVATION_TANH;
}
else if (activationName == AttrValue::ActivationScaledTanh)
}
else if (CompareActivationName(activationName, AttrValue::ActivationScaledTanh))
{
desc.Type = DML_OPERATOR_ACTIVATION_SCALED_TANH;
activationDesc.scaledTanh.Alpha = NextAlpha(desc.Type);
activationDesc.scaledTanh.Beta = NextBeta(desc.Type);
}
else if (activationName == AttrValue::ActivationSigmoid)
}
else if (CompareActivationName(activationName, AttrValue::ActivationSigmoid))
{
desc.Type = DML_OPERATOR_ACTIVATION_SIGMOID;
}
else if (activationName == AttrValue::ActivationSigmoidHard)
}
else if (CompareActivationName(activationName, AttrValue::ActivationSigmoidHard))
{
desc.Type = DML_OPERATOR_ACTIVATION_HARD_SIGMOID;
activationDesc.hardSigmoid.Alpha = NextAlpha(desc.Type);
activationDesc.hardSigmoid.Beta = NextBeta(desc.Type);
}
else if (activationName == AttrValue::ActivationElu)
}
else if (CompareActivationName(activationName, AttrValue::ActivationElu))
{
desc.Type = DML_OPERATOR_ACTIVATION_ELU;
activationDesc.elu.Alpha = NextAlpha(desc.Type);
}
else if (activationName == AttrValue::ActivationSoftsign)
}
else if (CompareActivationName(activationName, AttrValue::ActivationSoftsign))
{
desc.Type = DML_OPERATOR_ACTIVATION_SOFTSIGN;
}
else if (activationName == AttrValue::ActivationSoftplus)
}
else if (CompareActivationName(activationName, AttrValue::ActivationSoftplus))
{
desc.Type = DML_OPERATOR_ACTIVATION_SOFTPLUS;
}
Expand All @@ -182,6 +182,12 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper
}
}

// Compares two activation-function names case-insensitively (ASCII case folding).
// Rationale: attribute values may arrive with mixed casing (e.g. "Sigmoid",
// "LeakyRelu") while models/tests may supply lowercase names (e.g. "sigmoid");
// a case-sensitive match caused "unsupported activation function" errors.
// Returns true only when both views have the same length and every character
// pair compares equal after lowering.
bool CompareActivationName(std::string_view activationName, std::string_view attrValue)
{
    auto comparer = [](char a, char b)
    {
        // Cast to unsigned char first: passing a negative char value (possible
        // for non-ASCII bytes when char is signed) to std::tolower is UB.
        return std::tolower(static_cast<unsigned char>(a)) == std::tolower(static_cast<unsigned char>(b));
    };
    return std::equal(activationName.begin(), activationName.end(), attrValue.begin(), attrValue.end(), comparer);
}

void Compute(const MLOperatorKernelContext& kernelContext) override
{
// Assume that enough GPU work has been queued up after the RNN operator that it is worth
Expand Down
Loading

0 comments on commit 685f5a5

Please sign in to comment.