After failing to run inference with a finetuned wav2vec2 .pt checkpoint on a single audio file using the recognize.py script provided in the famous issue #2561, I created my own script for this problem, which runs for me without any errors. The custom model was trained in the Hydra framework.
NOTE: this script is for checkpoints whose loaded dictionary has a 'cfg' field, as opposed to checkpoints that have an 'args' field. Inspiration was taken from issue #3043, which also used a model with the 'cfg' field instead of the 'args' field.
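To check which of the two layouts a checkpoint uses before picking a script, the loaded dictionary can be inspected directly. The following is a minimal sketch, assuming the checkpoint has already been loaded into a plain dict (for example with `torch.load(path, map_location="cpu")`). The helper name `checkpoint_config_style` and the toy dicts are hypothetical and not part of fairseq. A key that is present but set to `None` is treated as missing, on the assumption that some checkpoints carry both keys with only one populated.

```python
def checkpoint_config_style(state):
    """Return 'cfg', 'args', or 'unknown' for a loaded fairseq-style checkpoint dict.

    Hypothetical helper for illustration; `state` is whatever torch.load returned.
    """
    # A key that exists but holds None is treated the same as a missing key.
    if state.get("cfg") is not None:
        return "cfg"
    if state.get("args") is not None:
        return "args"
    return "unknown"


if __name__ == "__main__":
    # Toy stand-ins for real checkpoints (placeholder values only).
    hydra_style = {"cfg": {"model": {"w2v_path": "/path/to/base.pt"}}, "args": None, "model": {}}
    legacy_style = {"args": {"arch": "wav2vec_ctc"}, "model": {}}
    print(checkpoint_config_style(hydra_style))
    print(checkpoint_config_style(legacy_style))
```

If the helper reports 'cfg', the script in this issue applies; if it reports 'args', the older recognize.py approach from #2561 is probably the better starting point.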
Code sample
# Run ASR inference using a wav2vec2 ASR model and a specified decoder on a single audio file.
# Used for wav2vec2 ASR checkpoints that were finetuned in the Hydra framework
# (loaded checkpoint has 'cfg' key but no 'args' key).

import torch
import soundfile as sf
from argparse import Namespace
import torch.nn.functional as F
from omegaconf import OmegaConf
from fairseq.data import Dictionary
from fairseq.data.data_utils import post_process
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
from fairseq.models.wav2vec.wav2vec2_asr import Wav2VecCtc, Wav2Vec2CtcConfig


def get_config_dict(args):
    if isinstance(args, Namespace):
        # unpack Namespace into base dict obj
        args = vars(args)
    fields = Wav2Vec2CtcConfig.__dataclass_fields__
    # create dict for attributes of Wav2Vec2CtcConfig with vals taken from the same key in args, if they exist
    fields_dict = {}
    # this means Wav2Vec2CtcConfig obj fields will be overwritten with vals from args, otherwise they will be default
    for field in fields.keys():
        if field in args:
            fields_dict[field] = args[field]
    return fields_dict


def get_feature(filepath):
    def postprocess(feats, sample_rate):
        # feats.dim is a method and must be called; the original compared the
        # method object itself to 2, so multi-channel audio was never averaged
        if feats.dim() == 2:
            # average the channels to get a mono signal
            feats = feats.mean(-1)
        assert feats.dim() == 1, feats.dim()
        with torch.no_grad():
            # normalize the whole waveform to zero mean and unit variance
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    feats = postprocess(feats, sample_rate)
    feats = feats.cuda()
    return feats


if __name__ == "__main__":
    model_path = "/path/to/checkpoint.pt"
    target_dict = Dictionary.load('/path/to/dict.ltr.txt')

    # load the checkpoint and build the model config from its 'cfg' entry
    w2v = torch.load(model_path)
    args_dict = get_config_dict(w2v['cfg']['model'])
    w2v_config_obj = OmegaConf.merge(OmegaConf.structured(Wav2Vec2CtcConfig), args_dict)

    # build_model only needs an object exposing 'target_dictionary'
    dummy_target_dict = {'target_dictionary': target_dict.symbols}
    dummy_target_dict = Namespace(**dummy_target_dict)

    model = Wav2VecCtc.build_model(w2v_config_obj, dummy_target_dict)
    model.load_state_dict(w2v["model"], strict=True)
    model = model.cuda()
    model.eval()

    sample, input = dict(), dict()
    WAV_PATH = '/path/to/speech.wav'

    # define additional decoder args
    decoder_args = Namespace(**{'nbest': 1})
    generator = W2lViterbiDecoder(decoder_args, target_dict)

    feature = get_feature(WAV_PATH)
    input["source"] = feature.unsqueeze(0)

    # single unpadded utterance, so the padding mask is all False
    padding_mask = torch.BoolTensor(input["source"].size(1)).fill_(False).unsqueeze(0)

    input["padding_mask"] = padding_mask
    sample["net_input"] = input

    models = list()
    models.append(model)

    with torch.no_grad():
        hypo = generator.generate(models, sample, prefix_tokens=None)

    # best hypothesis of the first (only) utterance, mapped back to letters
    hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
    res = post_process(hyp_pieces, 'letter')
    print(res)
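For readers unfamiliar with the last steps of the script (decoding the model output and turning letter tokens into text), the idea can be sketched without fairseq or flashlight. The example below is a simplified greedy CTC decoder, not the algorithm `W2lViterbiDecoder` actually runs, and `letters_to_text` only mirrors what the 'letter' post-processing is assumed to do (drop the spaces between letters, then turn the `|` word-boundary symbol into a space). The function names, the toy vocabulary, and the blank index are all hypothetical.

```python
def ctc_greedy_collapse(frame_ids, blank_id):
    """Collapse per-frame token ids: merge consecutive repeats, then drop blanks."""
    out = []
    prev = None
    for idx in frame_ids:
        # Keep a token only when it differs from the previous frame and is not blank.
        if idx != prev and idx != blank_id:
            out.append(idx)
        prev = idx
    return out


def letters_to_text(pieces):
    """Join space-separated letter tokens, treating '|' as the word boundary."""
    return pieces.replace(" ", "").replace("|", " ").strip()


if __name__ == "__main__":
    # Hypothetical toy vocabulary; index 0 plays the role of the CTC blank here.
    vocab = ["<blank>", "|", "H", "I", "T", "E", "R"]
    # Pretend these are the argmax token ids of each output frame.
    frames = [2, 2, 0, 3, 3, 1, 1, 0, 4, 4, 2, 5, 0, 6, 6, 5, 1]
    ids = ctc_greedy_collapse(frames, blank_id=0)
    pieces = " ".join(vocab[i] for i in ids)
    print(letters_to_text(pieces))
```

A blank between two identical ids keeps both of them, which is how CTC represents doubled letters.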
Environment
fairseq Version (e.g., 1.0 or main): 1.0.0a0+35cc605
flashlight version 1.0.0
Hydra version 1.0.7
PyTorch Version (e.g., 1.0): 1.12.1+cu113
OS (e.g., Linux): Ubuntu 20.04.3 LTS (GNU/Linux 5.4.0-122-generic x86_64)
How you installed fairseq (pip, source): source
Build command you used (if compiling from source): pip3 install packaging && pip3 install --editable ./
Python version: 3.8.5
CUDA/cuDNN version: 11.4
Cuda compilation tools, release 11.1, V11.1.105
GPU models and configuration: A6000