
Inference on AudioSet #37

Open
nandacv opened this issue Aug 2, 2023 · 3 comments

Comments

@nandacv

nandacv commented Aug 2, 2023

Thank you for the code and inference script.
I understand that the PaSST model was trained on AudioSet at a sampling rate of 32 kHz.
I am trying to run inference using the pre-trained model.
Could you please let me know whether I have to retrain the model on AudioSet data at a 16 kHz sampling rate to run inference on 16 kHz audio, or is there another way?

Also, I am curious why you used 32 kHz instead of the already available 16 kHz AudioSet data?

Thanks in advance.

@kkoutini
Owner

kkoutini commented Aug 7, 2023

Hi, thank you!

I think that to get the best performance, it's better to retrain on 16 kHz.
Alternatively, you can adapt the pre-trained model to accept 16 kHz input like this:

First get the models as usual:

from hear21passt.base import get_basic_model, get_model_passt

model = get_basic_model(mode="logits")

Then replace the mel layer with this adapted config:

from hear21passt.models.preprocess import AugmentMelSTFT

model.mel = AugmentMelSTFT(n_mels=128, sr=16000, win_length=400, hopsize=160,
                           n_fft=512, freqm=48, timem=192, htk=False, fmin=0.0,
                           fmax=None, norm=1, fmin_aug_range=10,
                           fmax_aug_range=1000)

You can compare it with the original mel layer here: https://github.com/kkoutini/passt_hear21/blob/4dd6b9e426f528e2e8409b9bacecf58a2f464548/hear21passt/base.py#L52
The main differences in the original were: sr=32000, win_length=800, hopsize=320, n_fft=1024.
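A quick sanity check (not from the original thread) shows why these are the right values: halving win_length, hopsize, and roughly n_fft along with the sampling rate keeps the window and hop durations, and hence the spectrogram's time resolution, unchanged:

```python
# Each STFT hop and window covers the same duration at both sampling rates,
# so a clip of a given length yields the same number of mel frames.
hop_ms_16k = 160 / 16000 * 1000  # hop duration in ms at 16 kHz
hop_ms_32k = 320 / 32000 * 1000  # hop duration in ms at 32 kHz
win_ms_16k = 400 / 16000 * 1000  # window duration in ms at 16 kHz
win_ms_32k = 800 / 32000 * 1000  # window duration in ms at 32 kHz

print(hop_ms_16k, hop_ms_32k)  # 10.0 10.0
print(win_ms_16k, win_ms_32k)  # 25.0 25.0
```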
I hope this helps.

The audio files I downloaded were in 32 kHz.

@nandacv
Author

nandacv commented Aug 7, 2023

Thank you for the reply.
Can you please confirm if the following code looks good?

from hear21passt.base import get_basic_model, get_model_passt
from hear21passt.models.preprocess import AugmentMelSTFT
import torch

# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel)  # extracts mel spectrogram from raw waveforms

# replace the mel layer with the 16 kHz configuration
model.mel = AugmentMelSTFT(n_mels=128, sr=16000, win_length=400, hopsize=160,
                           n_fft=512, freqm=48, timem=192, htk=False, fmin=0.0,
                           fmax=None, norm=1, fmin_aug_range=10,
                           fmax_aug_range=1000)

# example inference
model.eval()
with torch.no_grad():
    # audio has the shape [batch, seconds * 16000]; sampling rate is 16 kHz
    # example batch of 3 clips, 10 seconds each
    audio = torch.ones((3, 16000 * 10)) * 0.5
    logits = model(audio)

Also, I assume these logits should be passed through a sigmoid function to get the output class probabilities? Please correct me if I am wrong.

Thanks in advance.

@kkoutini
Owner

kkoutini commented Aug 8, 2023

Yes, this looks correct.
