diff --git a/include/rppt_tensor_audio_augmentations.h b/include/rppt_tensor_audio_augmentations.h
index fec05b615..31bb34eff 100644
--- a/include/rppt_tensor_audio_augmentations.h
+++ b/include/rppt_tensor_audio_augmentations.h
@@ -101,15 +101,14 @@ RppStatus rppt_pre_emphasis_filter_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr,
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[out] dstPtr destination tensor in HOST memory
* \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
-* \param[in] srcLengthTensor source audio buffer length (1D tensor in HOST memory, of size batchSize)
-* \param[in] channelsTensor number of channels in audio buffer (1D tensor in HOST memory, of size batchSize)
+* \param[in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
* \param[in] normalizeWeights bool flag to specify if normalization of weights is needed
* \param[in] rppHandle RPP HOST handle created with \ref rppCreateWithBatchSize()
* \return A \ref RppStatus enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
*/
-RppStatus rppt_down_mixing_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcLengthTensor, Rpp32s *channelsTensor, bool normalizeWeights, rppHandle_t rppHandle);
+RppStatus rppt_down_mixing_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcDimsTensor, bool normalizeWeights, rppHandle_t rppHandle);
#ifdef __cplusplus
}
diff --git a/src/modules/cpu/kernel/down_mixing.hpp b/src/modules/cpu/kernel/down_mixing.hpp
index b17da21b8..9cefc64a2 100644
--- a/src/modules/cpu/kernel/down_mixing.hpp
+++ b/src/modules/cpu/kernel/down_mixing.hpp
@@ -29,8 +29,7 @@ RppStatus down_mixing_host_tensor(Rpp32f *srcPtr,
RpptDescPtr srcDescPtr,
Rpp32f *dstPtr,
RpptDescPtr dstDescPtr,
- Rpp32s *srcLengthTensor,
- Rpp32s *channelsTensor,
+ Rpp32s *srcDimsTensor,
bool normalizeWeights,
rpp::Handle& handle)
{
@@ -43,8 +42,8 @@ RppStatus down_mixing_host_tensor(Rpp32f *srcPtr,
Rpp32f *srcPtrTemp = srcPtr + batchCount * srcDescPtr->strides.nStride;
Rpp32f *dstPtrTemp = dstPtr + batchCount * dstDescPtr->strides.nStride;
- Rpp32s channels = channelsTensor[batchCount];
- Rpp32s samples = srcLengthTensor[batchCount];
+ Rpp32s samples = srcDimsTensor[batchCount * 2];
+ Rpp32s channels = srcDimsTensor[batchCount * 2 + 1];
bool flagAVX = 0;
if(channels == 1)
diff --git a/src/modules/rppt_tensor_audio_augmentations.cpp b/src/modules/rppt_tensor_audio_augmentations.cpp
index cb01e8012..d78b8890a 100644
--- a/src/modules/rppt_tensor_audio_augmentations.cpp
+++ b/src/modules/rppt_tensor_audio_augmentations.cpp
@@ -133,8 +133,7 @@ RppStatus rppt_down_mixing_host(RppPtr_t srcPtr,
RpptDescPtr srcDescPtr,
RppPtr_t dstPtr,
RpptDescPtr dstDescPtr,
- Rpp32s *srcLengthTensor,
- Rpp32s *channelsTensor,
+ Rpp32s *srcDimsTensor,
bool normalizeWeights,
rppHandle_t rppHandle)
{
@@ -144,8 +143,7 @@ RppStatus rppt_down_mixing_host(RppPtr_t srcPtr,
srcDescPtr,
static_cast(dstPtr),
dstDescPtr,
- srcLengthTensor,
- channelsTensor,
+ srcDimsTensor,
normalizeWeights,
rpp::deref(rppHandle));
diff --git a/utilities/test_suite/HOST/Tensor_host_audio.cpp b/utilities/test_suite/HOST/Tensor_host_audio.cpp
index 840441525..fe6fa1246 100644
--- a/utilities/test_suite/HOST/Tensor_host_audio.cpp
+++ b/utilities/test_suite/HOST/Tensor_host_audio.cpp
@@ -201,15 +201,18 @@ int main(int argc, char **argv)
{
testCaseName = "down_mixing";
bool normalizeWeights = false;
+ Rpp32s srcDimsTensor[batchSize * 2];
- for (int i = 0; i < batchSize; i++)
+ for (int i = 0, j = 0; i < batchSize; i++, j += 2)
{
- srcDims[i].height = dstDims[i].height = srcLengthTensor[i];
- srcDims[i].width = dstDims[i].width = 1;
+ srcDimsTensor[j] = srcLengthTensor[i];
+ srcDimsTensor[j + 1] = channelsTensor[i];
+ dstDims[i].height = srcLengthTensor[i];
+ dstDims[i].width = 1;
}
startWallTime = omp_get_wtime();
- rppt_down_mixing_host(inputf32, srcDescPtr, outputf32, dstDescPtr, srcLengthTensor, channelsTensor, normalizeWeights, handle);
+ rppt_down_mixing_host(inputf32, srcDescPtr, outputf32, dstDescPtr, srcDimsTensor, normalizeWeights, handle);
break;
}