Skip to content

Commit

Permalink
Merge pull request #286 from HazarathKumarM/hk/mel_filter_bank_hip
Browse files Browse the repository at this point in the history
Audio HIP PR6 - Mel Filter Bank HIP Support
  • Loading branch information
r-abishek authored Aug 7, 2024
2 parents 5c3772a + f9e70ec commit 6169ae3
Show file tree
Hide file tree
Showing 12 changed files with 461 additions and 113 deletions.
73 changes: 71 additions & 2 deletions include/rppdefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ SOFTWARE.
} \
} while (0)

// RPP_HOST_DEVICE marks a function as callable from both host and device code.
// It expands to `__host__ __device__` when HIP_COMPILE is defined and to nothing
// otherwise, so the same declarations build in both configurations.
#ifdef HIP_COMPILE
#define RPP_HOST_DEVICE __host__ __device__
#else
#define RPP_HOST_DEVICE
#endif

const float ONE_OVER_6 = 1.0f / 6;
const float ONE_OVER_3 = 1.0f / 3;
const float ONE_OVER_255 = 1.0f / 255;
Expand Down Expand Up @@ -145,7 +151,9 @@ typedef enum
/*! \brief Scratch memory size needed is beyond the bounds (Needs to adhere to function specification.) \ingroup group_rppdefs */
RPP_ERROR_OUT_OF_BOUND_SCRATCH_MEMORY_SIZE = -22,
/*! \brief Number of src dims is invalid. (Needs to adhere to function specification.) \ingroup group_rppdefs */
RPP_ERROR_INVALID_SRC_DIMS = -23
RPP_ERROR_INVALID_SRC_DIMS = -23,
/*! \brief Number of dst dims is invalid. (Needs to adhere to function specification.) \ingroup group_rppdefs */
RPP_ERROR_INVALID_DST_DIMS = -24
} RppStatus;

/*! \brief RPP rppStatus_t type enums
Expand Down Expand Up @@ -738,6 +746,67 @@ typedef struct RpptResamplingWindow
__m128 pCenter, pScale;
} RpptResamplingWindow;

/*! \brief Abstract interface for conversions between hertz and the mel scale.
 * \details Concrete scales provide both conversion directions. The destructor is
 *          virtual so an implementation can be deleted through a BaseMelScale pointer.
 * \ingroup group_rppdefs
 */
struct BaseMelScale
{
    /*! \brief Converts a frequency in hertz to its mel-scale value. */
    RPP_HOST_DEVICE virtual Rpp32f hz_to_mel(Rpp32f hz) = 0;

    /*! \brief Converts a mel-scale value back to a frequency in hertz. */
    RPP_HOST_DEVICE virtual Rpp32f mel_to_hz(Rpp32f mel) = 0;

    virtual ~BaseMelScale() = default;
};

/*! \brief Derived class for HTK Mel scale conversions.
* \ingroup group_rppdefs
*/
struct HtkMelScale : public BaseMelScale
{
inline RPP_HOST_DEVICE Rpp32f hz_to_mel(Rpp32f hz) { return 1127.0f * std::log(1.0f + (hz / 700.0f)); }
inline RPP_HOST_DEVICE Rpp32f mel_to_hz(Rpp32f mel) { return 700.0f * (std::exp(mel / 1127.0f) - 1.0f); }
public:
~HtkMelScale() {};
};

/*! \brief Derived class for Slaney Mel scale conversions.
* \ingroup group_rppdefs
*/
struct SlaneyMelScale : public BaseMelScale
{
const Rpp32f freqLow = 0;
const Rpp32f fsp = 66.666667f;
const Rpp32f minLogHz = 1000.0;
const Rpp32f minLogMel = (minLogHz - freqLow) / fsp;
const Rpp32f stepLog = 0.068751777; // Equivalent to std::log(6.4) / 27.0;

const Rpp32f invMinLogHz = 0.001f;
const Rpp32f invStepLog = 1.0f / stepLog;
const Rpp32f invFsp = 1.0f / fsp;

inline RPP_HOST_DEVICE Rpp32f hz_to_mel(Rpp32f hz)
{
Rpp32f mel = 0.0f;
if (hz >= minLogHz)
mel = minLogMel + std::log(hz * invMinLogHz) * invStepLog;
else
mel = (hz - freqLow) * invFsp;

return mel;
}

inline RPP_HOST_DEVICE Rpp32f mel_to_hz(Rpp32f mel)
{
Rpp32f hz = 0.0f;
if (mel >= minLogMel)
hz = minLogHz * std::exp(stepLog * (mel - minLogMel));
else
hz = freqLow + mel * fsp;
return hz;
}
public:
~SlaneyMelScale() {};
};

/******************** HOST memory typedefs ********************/

/*! \brief RPP HOST 32-bit float memory
Expand Down Expand Up @@ -1055,7 +1124,7 @@ typedef struct
Rpp64u* dstBatchIndex;
Rpp32u* inc;
Rpp32u* dstInc;
hipMemRpp32u scratchBuf;
hipMemRpp32f scratchBufferPinned;
} memGPU;

/*! \brief RPP HIP-HOST memory management
Expand Down
22 changes: 22 additions & 0 deletions include/rppt_tensor_audio_augmentations.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,28 @@ RppStatus rppt_spectrogram_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_
*/
RppStatus rppt_mel_filter_bank_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcDims, Rpp32f maxFreq, Rpp32f minFreq, RpptMelScaleFormula melFormula, Rpp32s numFilter, Rpp32f sampleRate, bool normalize, rppHandle_t rppHandle);

#ifdef GPU_SUPPORT
/*! \brief Mel filter bank augmentation on HIP backend
 * \details Mel filter bank augmentation for audio data
 * \param[in] srcPtr source tensor in HIP memory
 * \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32, layout - NFT)
 * \param[out] dstPtr destination tensor in HIP memory
 * \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32, layout - NFT)
 * \param[in] srcDims source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
 * \param[in] maxFreq maximum frequency if not provided maxFreq = sampleRate / 2
 * \param[in] minFreq minimum frequency
 * \param[in] melFormula formula used to convert frequencies from hertz to mel and from mel to hertz (SLANEY / HTK)
 * \param[in] numFilter number of mel filters
 * \param[in] sampleRate sampling rate of the audio
 * \param[in] normalize boolean that determines whether the filter weights are normalized
 * \param[in] rppHandle RPP HIP handle created with <tt>\ref rppCreateWithStreamAndBatchSize()</tt>
 * \note NOTE(review): the mel_filter_bank_tensor kernel dereferences the srcDims pointer on the device, so "HOST memory" presumably has to mean pinned / device-accessible memory - confirm against the caller.
 * \note NOTE(review): hip_exec_mel_filter_bank_tensor always uses sampleRate / 2 as the maximum frequency and never reads its maxFreqVal argument - confirm whether maxFreq is meant to be honoured.
 * \return A <tt> \ref RppStatus</tt> enumeration.
 * \retval RPP_SUCCESS Successful completion.
 * \retval RPP_ERROR* Unsuccessful completion.
 */
RppStatus rppt_mel_filter_bank_gpu(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcDims, Rpp32f maxFreq, Rpp32f minFreq, RpptMelScaleFormula melFormula, Rpp32s numFilter, Rpp32f sampleRate, bool normalize, rppHandle_t rppHandle);
#endif

/*! \brief Resample augmentation on HOST backend
* \details Resample augmentation for audio data
* \param [in] srcPtr source tensor in HOST memory
Expand Down
52 changes: 0 additions & 52 deletions src/modules/cpu/kernel/mel_filter_bank.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,58 +26,6 @@ SOFTWARE.
#include "rpp_cpu_simd.hpp"
#include "rpp_cpu_common.hpp"

// Abstract interface for hertz <-> mel conversions. The virtual destructor lets
// implementations be deleted through a BaseMelScale pointer.
struct BaseMelScale
{
    // Converts a frequency in hertz to its mel-scale value.
    virtual Rpp32f hz_to_mel(Rpp32f hz) = 0;

    // Converts a mel-scale value back to a frequency in hertz.
    virtual Rpp32f mel_to_hz(Rpp32f mel) = 0;

    virtual ~BaseMelScale() = default;
};

// HTK mel scale: mel = 1127 * ln(1 + hz / 700), and its inverse.
struct HtkMelScale : public BaseMelScale
{
    Rpp32f hz_to_mel(Rpp32f hz)
    {
        return 1127.0f * std::log(1.0f + (hz / 700.0f));
    }

    Rpp32f mel_to_hz(Rpp32f mel)
    {
        return 700.0f * (std::exp(mel / 1127.0f) - 1.0f);
    }

    ~HtkMelScale() {}
};

// Slaney mel scale: linear below minLogHz (1000 Hz), logarithmic at and above it.
struct SlaneyMelScale : public BaseMelScale
{
    const Rpp32f freqLow = 0;
    const Rpp32f fsp = 200.0 / 3.0;
    const Rpp32f minLogHz = 1000.0;
    const Rpp32f minLogMel = (minLogHz - freqLow) / fsp;
    const Rpp32f stepLog = 0.068751777; // Equivalent to std::log(6.4) / 27.0;

    // Reciprocals stored once so the conversions multiply instead of divide.
    const Rpp32f invMinLogHz = 1.0f / 1000.0;
    const Rpp32f invStepLog = 1.0f / stepLog;
    const Rpp32f invFsp = 1.0f / fsp;

    // Converts hertz to mel, picking the logarithmic or linear branch by threshold.
    Rpp32f hz_to_mel(Rpp32f hz)
    {
        return (hz >= minLogHz) ? (minLogMel + std::log(hz * invMinLogHz) * invStepLog)
                                : ((hz - freqLow) * invFsp);
    }

    // Converts mel to hertz; inverse of hz_to_mel.
    Rpp32f mel_to_hz(Rpp32f mel)
    {
        return (mel >= minLogMel) ? (minLogHz * std::exp(stepLog * (mel - minLogMel)))
                                  : (freqLow + mel * fsp);
    }

    ~SlaneyMelScale() {}
};

RppStatus mel_filter_bank_host_tensor(Rpp32f *srcPtr,
RpptDescPtr srcDescPtr,
Rpp32f *dstPtr,
Expand Down
2 changes: 2 additions & 0 deletions src/modules/hip/handlehip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ struct HandleImpl
- 293 is the size required for storing reduction outputs for 600000 size sample
- 128 is the size required for storing cutOffDB values for batch size 128 */
hipMalloc(&(this->initHandle->mem.mgpu.scratchBufferHip.floatmem), sizeof(Rpp32f) * 76853888);
hipHostMalloc(&(this->initHandle->mem.mgpu.scratchBufferPinned.floatmem), sizeof(Rpp32f) * 8294400);
}
};

Expand Down Expand Up @@ -362,6 +363,7 @@ void Handle::rpp_destroy_object_gpu()

hipFree(this->GetInitHandle()->mem.mgpu.rgbArr.rgbmem);
hipFree(this->GetInitHandle()->mem.mgpu.scratchBufferHip.floatmem);
hipHostFree(this->GetInitHandle()->mem.mgpu.scratchBufferPinned.floatmem);
}

void Handle::rpp_destroy_object_host()
Expand Down
1 change: 1 addition & 0 deletions src/modules/hip/hip_tensor_audio_augmentations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ SOFTWARE.

#include "kernel/non_silent_region_detection.hpp"
#include "kernel/down_mixing.hpp"
#include "kernel/mel_filter_bank.hpp"
#include "kernel/to_decibels.hpp"

#endif // HIP_TENSOR_AUDIO_AUGMENTATIONS_HPP
160 changes: 160 additions & 0 deletions src/modules/hip/kernel/mel_filter_bank.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#include <hip/hip_runtime.h>
#include "rpp_hip_common.hpp"

// Accumulates one mel-bin output value into dstVal.
//   srcPtr      - base pointer the FFT-bin offsets are applied to
//   melBin      - mel filter index; intervals[melBin .. melBin + 2] delimit its two FFT-bin ranges
//   weightsDown - per-FFT-bin falling-slope weight; the rising-slope weight is 1 - weightsDown
//   intervals   - FFT-bin boundaries of the filters
//   fftStrides  - x: element stride between consecutive FFT bins, y: fixed element offset added once
//   normFactor  - scale applied to every weight
//   dstVal      - output; reset to 0 and then accumulated in place
__device__ __forceinline__ void compute_mel(float *srcPtr, int melBin, float *weightsDown, int *intervals, int2 fftStrides, float normFactor, float &dstVal)
{
    dstVal = 0;

    // Rising half of the filter: FFT bins in [intervals[melBin], intervals[melBin + 1])
    int bin = intervals[melBin];
    int binLimit = intervals[melBin + 1];
    float *cursor = srcPtr + bin * fftStrides.x + fftStrides.y;
    while (bin < binLimit)
    {
        dstVal += *cursor * ((1.0f - weightsDown[bin]) * normFactor);
        cursor += fftStrides.x;
        bin++;
    }

    // Falling half of the filter: continues from the current bin up to intervals[melBin + 2]
    binLimit = intervals[melBin + 2];
    cursor = srcPtr + bin * fftStrides.x + fftStrides.y;
    while (bin < binLimit)
    {
        dstVal += *cursor * (weightsDown[bin] * normFactor);
        cursor += fftStrides.x;
        bin++;
    }
}

// Kernel: each thread produces one destination element, addressed by
// (x index, mel filter index from y, batch index from z).
// Threads whose x index is not below srcDimsTensor[2 * batch + 1], or whose filter
// index is not below numFilter, exit without writing.
// srcStridesNH / dstStridesNH carry the (n, h) element strides of source and destination.
__global__ void mel_filter_bank_tensor(float *srcPtr,
                                       uint2 srcStridesNH,
                                       float *dstPtr,
                                       uint2 dstStridesNH,
                                       int *srcDimsTensor,
                                       int numFilter,
                                       bool normalize,
                                       float *normFactors,
                                       float *weightsDown,
                                       int *intervals)
{
    int col = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
    int filterIdx = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y;
    int batchIdx = hipBlockIdx_z * hipBlockDim_z + hipThreadIdx_z;

    if (col >= srcDimsTensor[batchIdx * 2 + 1] || filterIdx >= numFilter)
        return;

    uint srcBase = batchIdx * srcStridesNH.x;
    uint dstOffset = batchIdx * dstStridesNH.x + filterIdx * dstStridesNH.y + col;

    // Per-filter normalization factor, or 1 when normalization is disabled
    float scale = 1;
    if (normalize)
        scale = normFactors[filterIdx];

    compute_mel(srcPtr + srcBase, filterIdx, weightsDown, intervals, make_int2(srcStridesNH.y, col), scale, dstPtr[dstOffset]);
}

// Builds the mel filter bank lookup tables on the host (in the handle's pinned
// scratch buffer) and launches mel_filter_bank_tensor on the handle's stream.
//
//   srcPtr / srcDescPtr - source tensor and descriptor; srcDescPtr->h is used as the number of FFT bins
//   dstPtr / dstDescPtr - destination tensor and descriptor; its w / h / n size the launch grid
//   srcDimsTensor       - per-sample dims, read by the kernel (two ints per sample)
//   maxFreqVal          - NOTE(review): currently unused, the upper frequency is always sampleRate / 2;
//                         confirm whether the caller-supplied value should be honoured
//   minFreqVal          - lower frequency bound in hertz
//   melFormula          - HTK selects HtkMelScale; anything else selects SlaneyMelScale
//   numFilter           - number of mel filters
//   sampleRate          - sampling rate in hertz
//   normalize           - when true, each filter is scaled by 2 / (f2 - f0)
//   handle              - provides the pinned scratch buffer and the HIP stream
//
// Scratch layout (in floats): [normFactors: numFilter][weightsDown: srcDescPtr->h][intervals: numFilter + 2, as Rpp32s]
// NOTE(review): the tables are written by the host and read by an asynchronously launched kernel;
// presumably callers synchronize the stream before the scratch buffer is reused - confirm.
//
// Returns RPP_SUCCESS.
RppStatus hip_exec_mel_filter_bank_tensor(Rpp32f *srcPtr,
                                          RpptDescPtr srcDescPtr,
                                          Rpp32f *dstPtr,
                                          RpptDescPtr dstDescPtr,
                                          Rpp32s* srcDimsTensor,
                                          Rpp32f maxFreqVal,
                                          Rpp32f minFreqVal,
                                          RpptMelScaleFormula melFormula,
                                          Rpp32s numFilter,
                                          Rpp32f sampleRate,
                                          bool normalize,
                                          rpp::Handle& handle)
{
    // Create an instance of the MelScale class based on the chosen formula
    BaseMelScale *melScalePtr;
    switch (melFormula)
    {
        case RpptMelScaleFormula::HTK:
            melScalePtr = new HtkMelScale;
            break;
        case RpptMelScaleFormula::SLANEY:
        default:
            melScalePtr = new SlaneyMelScale();
            break;
    }

    Rpp32f maxFreq = sampleRate / 2;
    Rpp32f minFreq = minFreqVal;

    // Convert the frequency range to Mel scale and compute Mel step size
    Rpp64f melLow = melScalePtr->hz_to_mel(minFreq);
    Rpp64f melHigh = melScalePtr->hz_to_mel(maxFreq);
    Rpp64f melStep = (melHigh - melLow) / (numFilter + 1);

    // Carve the three lookup tables out of the pinned scratch buffer
    Rpp32f *scratchMem = handle.GetInitHandle()->mem.mgpu.scratchBufferPinned.floatmem;
    Rpp32f *normFactors = scratchMem;
    Rpp32f *weightsDown = scratchMem + numFilter;
    Rpp32s *intervals = reinterpret_cast<Rpp32s *>(weightsDown + srcDescPtr->h);

    // parameters for FFT and frequency bins
    Rpp32s nfft = (srcDescPtr->h - 1) * 2;
    Rpp32s numBins = nfft / 2 + 1;
    Rpp64f hzStep = static_cast<Rpp64f>(sampleRate) / nfft;
    Rpp64f invHzStep = 1.0 / hzStep;

    // start and end bins for the Mel filter bank
    Rpp32s fftBinStart = std::ceil(minFreq * invHzStep);
    Rpp32s fftBinEnd = std::ceil(maxFreq * invHzStep);
    fftBinEnd = std::min(fftBinEnd, numBins);

    // Initialize arrays used for Mel filter bank computation.
    // weightsDown must be cleared over all srcDescPtr->h entries: the scratch buffer is shared
    // with other kernels, and bins that the loop below does not write are expected to read as 0.
    // (The earlier memset passed sizeof(<expression>), which cleared only sizeof(size_t) bytes.)
    std::fill(normFactors, normFactors + numFilter, 1.0f);
    std::fill(weightsDown, weightsDown + srcDescPtr->h, 0.0f);
    std::fill(intervals, intervals + numFilter + 2, -1);

    // Compute Mel filter weights and intervals
    Rpp32s fftBin = fftBinStart;
    Rpp64f mel0 = melLow, mel1 = melLow + melStep;
    Rpp64f fIter = fftBin * hzStep;

    intervals[0] = fftBinStart;
    intervals[numFilter + 1] = fftBinEnd;

    for (int interval = 1, index = 0; index < numFilter + 1; interval++, index++, mel0 = mel1, mel1 += melStep)
    {
        Rpp64f f0 = melScalePtr->mel_to_hz(mel0);
        // The last interval ends exactly at melHigh to avoid accumulated step error
        Rpp64f f1 = melScalePtr->mel_to_hz(index == numFilter ? melHigh : mel1);
        Rpp64f slope = 1.0 / (f1 - f0);
        intervals[interval] = std::ceil(f1 / hzStep);

        if (normalize && index < numFilter)
        {
            Rpp64f f2 = melScalePtr->mel_to_hz(mel1 + melStep);
            normFactors[index] = 2.0 / (f2 - f0);
        }

        // Compute weights for each filter bank
        for (; fftBin < fftBinEnd && fIter < f1; fftBin++, fIter = fftBin * hzStep)
        {
            weightsDown[fftBin] = (f1 - fIter) * slope;
        }
    }

    // One thread per destination element: x spans dst width, y spans dst height, z spans the batch
    Rpp32s globalThreads_x = dstDescPtr->w;
    Rpp32s globalThreads_y = dstDescPtr->h;
    Rpp32s globalThreads_z = dstDescPtr->n;
    hipLaunchKernelGGL(mel_filter_bank_tensor,
                       dim3(ceil((float)globalThreads_x/LOCAL_THREADS_X), ceil((float)globalThreads_y/LOCAL_THREADS_Y), ceil((float)globalThreads_z/LOCAL_THREADS_Z)),
                       dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z),
                       0,
                       handle.GetStream(),
                       srcPtr,
                       make_uint2(srcDescPtr->strides.nStride, srcDescPtr->strides.hStride),
                       dstPtr,
                       make_uint2(dstDescPtr->strides.nStride, dstDescPtr->strides.hStride),
                       srcDimsTensor,
                       numFilter,
                       normalize,
                       normFactors,
                       weightsDown,
                       intervals);

    delete melScalePtr;
    return RPP_SUCCESS;
}
4 changes: 2 additions & 2 deletions src/modules/hip/kernel/normalize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1808,8 +1808,8 @@ RppStatus hip_exec_normalize_tensor(T *srcPtr,

// create buffer for paramShape and paramStride needed for generic kernel
Rpp32u *paramShape, *paramStrides;
paramShape = handle.GetInitHandle()->mem.mgpu.scratchBuf.uintmem;
paramStrides = handle.GetInitHandle()->mem.mgpu.scratchBuf.uintmem + (batchSize * tensorDims);
paramShape = reinterpret_cast<Rpp32u*>(handle.GetInitHandle()->mem.mgpu.scratchBufferPinned.floatmem);
paramStrides = reinterpret_cast<Rpp32u*>(handle.GetInitHandle()->mem.mgpu.scratchBufferPinned.floatmem) + (batchSize * tensorDims);

// do initial preprocessing, compute maxParamVolue and fill the values for paramShape and paramStrides
Rpp32u maxParamVolume;
Expand Down
Loading

0 comments on commit 6169ae3

Please sign in to comment.