diff --git a/include/rppt_tensor_audio_augmentations.h b/include/rppt_tensor_audio_augmentations.h index fec05b615..31bb34eff 100644 --- a/include/rppt_tensor_audio_augmentations.h +++ b/include/rppt_tensor_audio_augmentations.h @@ -101,15 +101,14 @@ RppStatus rppt_pre_emphasis_filter_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, * \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32) * \param[out] dstPtr destination tensor in HOST memory * \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32) -* \param[in] srcLengthTensor source audio buffer length (1D tensor in HOST memory, of size batchSize) -* \param[in] channelsTensor number of channels in audio buffer (1D tensor in HOST memory, of size batchSize) +* \param[in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2) * \param[in] normalizeWeights bool flag to specify if normalization of weights is needed * \param[in] rppHandle RPP HOST handle created with \ref rppCreateWithBatchSize() * \return A \ref RppStatus enumeration. * \retval RPP_SUCCESS Successful completion. * \retval RPP_ERROR* Unsuccessful completion. */ -RppStatus rppt_down_mixing_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcLengthTensor, Rpp32s *channelsTensor, bool normalizeWeights, rppHandle_t rppHandle); +RppStatus rppt_down_mixing_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcDimsTensor, bool normalizeWeights, rppHandle_t rppHandle); #ifdef __cplusplus } diff --git a/src/modules/cpu/kernel/down_mixing.hpp b/src/modules/cpu/kernel/down_mixing.hpp index b17da21b8..9cefc64a2 100644 --- a/src/modules/cpu/kernel/down_mixing.hpp +++ b/src/modules/cpu/kernel/down_mixing.hpp @@ -29,8 +29,7 @@ RppStatus down_mixing_host_tensor(Rpp32f *srcPtr, RpptDescPtr srcDescPtr, Rpp32f *dstPtr, RpptDescPtr dstDescPtr, - Rpp32s *srcLengthTensor, - Rpp32s *channelsTensor, + Rpp32s *srcDimsTensor, bool normalizeWeights, rpp::Handle& handle) { @@ -43,8 +42,8 @@ RppStatus down_mixing_host_tensor(Rpp32f *srcPtr, Rpp32f *srcPtrTemp = srcPtr + batchCount * srcDescPtr->strides.nStride; Rpp32f *dstPtrTemp = dstPtr + batchCount * dstDescPtr->strides.nStride; - Rpp32s channels = channelsTensor[batchCount]; - Rpp32s samples = srcLengthTensor[batchCount]; + Rpp32s samples = srcDimsTensor[batchCount * 2]; + Rpp32s channels = srcDimsTensor[batchCount * 2 + 1]; bool flagAVX = 0; if(channels == 1) diff --git a/src/modules/rppt_tensor_audio_augmentations.cpp b/src/modules/rppt_tensor_audio_augmentations.cpp index cb01e8012..d78b8890a 100644 --- a/src/modules/rppt_tensor_audio_augmentations.cpp +++ b/src/modules/rppt_tensor_audio_augmentations.cpp @@ -133,8 +133,7 @@ RppStatus rppt_down_mixing_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, - Rpp32s *srcLengthTensor, - Rpp32s *channelsTensor, + Rpp32s *srcDimsTensor, bool normalizeWeights, rppHandle_t rppHandle) { @@ -144,8 +143,7 @@ RppStatus rppt_down_mixing_host(RppPtr_t srcPtr, srcDescPtr, static_cast(dstPtr), dstDescPtr, - srcLengthTensor, - channelsTensor, + srcDimsTensor, normalizeWeights, rpp::deref(rppHandle)); diff --git a/utilities/test_suite/HOST/Tensor_host_audio.cpp b/utilities/test_suite/HOST/Tensor_host_audio.cpp index 840441525..fe6fa1246 100644 --- a/utilities/test_suite/HOST/Tensor_host_audio.cpp +++ b/utilities/test_suite/HOST/Tensor_host_audio.cpp @@ -201,15 +201,18 @@ int main(int argc, char **argv) { testCaseName = "down_mixing"; bool normalizeWeights = false; + Rpp32s srcDimsTensor[batchSize * 2]; - for (int i = 0; i < batchSize; i++) + for (int i = 0, j = 0; i < batchSize; i++, j += 2) { - srcDims[i].height = dstDims[i].height = srcLengthTensor[i]; - srcDims[i].width = dstDims[i].width = 1; + srcDimsTensor[j] = srcLengthTensor[i]; + srcDimsTensor[j + 1] = channelsTensor[i]; + dstDims[i].height = srcLengthTensor[i]; + dstDims[i].width = 1; } startWallTime = omp_get_wtime(); - rppt_down_mixing_host(inputf32, srcDescPtr, outputf32, dstDescPtr, srcLengthTensor, channelsTensor, normalizeWeights, handle); + rppt_down_mixing_host(inputf32, srcDescPtr, outputf32, dstDescPtr, srcDimsTensor, normalizeWeights, handle); break; }