Skip to content

Commit

Permalink
added validation check for srcDims
Browse files Browse the repository at this point in the history
  • Loading branch information
sampath1117 committed Jun 18, 2024
1 parent fd21921 commit 11f3e80
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion include/rppdefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ typedef enum
/*! \brief src and dst layout mismatch \ingroup group_rppdefs */
RPP_ERROR_LAYOUT_MISMATCH = -18,
/*! \brief Number of channels is invalid. (Needs to adhere to function specification.) \ingroup group_rppdefs */
RPP_ERROR_INVALID_CHANNELS = -19
RPP_ERROR_INVALID_CHANNELS = -19,
/*! \brief Number of src dims is invalid. (Needs to adhere to function specification.) \ingroup group_rppdefs */
RPP_ERROR_INVALID_SRC_DIMS = -20
} RppStatus;

/*! \brief RPP rppStatus_t type enums
Expand Down
2 changes: 2 additions & 0 deletions src/modules/hip/kernel/down_mixing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ __global__ void down_mixing_hip_tensor(float *srcPtr,
d_float8 dst_f8;
dst_f8.f4[0] = static_cast<float4>(0.0f);
dst_f8.f4[1] = dst_f8.f4[0];

// compute the output values for 8 dst locations
for (int j = 0; j < 8; j++)
{
int i = 0;
Expand Down
4 changes: 4 additions & 0 deletions src/modules/rppt_tensor_audio_augmentations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ RppStatus rppt_down_mixing_gpu(RppPtr_t srcPtr,
rppHandle_t rppHandle)
{
#ifdef HIP_COMPILE
Rpp32u tensorDims = srcDescPtr->numDims - 1; // exclude batchsize from input dims
if (tensorDims != 2)
return RPP_ERROR_INVALID_SRC_DIMS;

if ((srcDescPtr->dataType == RpptDataType::F32) && (dstDescPtr->dataType == RpptDataType::F32))
{
hip_exec_down_mixing_tensor(static_cast<Rpp32f*>(srcPtr),
Expand Down
3 changes: 3 additions & 0 deletions utilities/test_suite/HIP/Tensor_audio_hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ int main(int argc, char **argv)
set_audio_descriptor_dims_and_strides(srcDescPtr, batchSize, maxSrcHeight, maxSrcWidth, maxSrcChannels, offsetInBytes);
int maxDstChannels = maxSrcChannels;
if(testCase == 3)
{
srcDescPtr->numDims = 3;
maxDstChannels = 1;
}
set_audio_descriptor_dims_and_strides(dstDescPtr, batchSize, maxDstHeight, maxDstWidth, maxDstChannels, offsetInBytes);

// set buffer sizes for src/dst
Expand Down

0 comments on commit 11f3e80

Please sign in to comment.