wav2vec2 inference using custom trained checkpoint that uses cfg instead of args #4827

Closed
abarcovschi opened this issue Oct 25, 2022 · 0 comments

🐛 Bug

After failing to run inference on a single audio file with a finetuned wav2vec2 .pt checkpoint using the recognize.py script from the well-known issue #2561, I wrote my own script for this case, and it runs for me without errors. The custom model was finetuned with fairseq's Hydra-based configuration framework.

NOTE: this script is for models whose checkpoint, when loaded, returns a dictionary containing a 'cfg' key, in contrast to older checkpoints that contain an 'args' key. It was inspired by issue #3043, which also dealt with a model using 'cfg' instead of 'args'.
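
A quick way to tell which style a checkpoint uses is to load it and inspect the top-level keys (the path below is a placeholder):

import torch

ckpt = torch.load('/path/to/checkpoint.pt', map_location='cpu')
# Hydra-era checkpoints store their config under ckpt['cfg'];
# older checkpoints store an argparse Namespace under ckpt['args']
print('cfg' in ckpt, 'args' in ckpt)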

Code sample

# run ASR inference using a wav2vec2 ASR model and a specified decoder on a single audio file.
# used for wav2vec2 ASR checkpoints that were finetuned in the Hydra framework (loaded checkpoint has 'cfg' key but no 'args' key).


import torch
import soundfile as sf
from argparse import Namespace
import torch.nn.functional as F
from omegaconf import OmegaConf
from fairseq.data import Dictionary
from fairseq.data.data_utils import post_process
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
from fairseq.models.wav2vec.wav2vec2_asr import Wav2VecCtc, Wav2Vec2CtcConfig


def get_config_dict(args):
    if isinstance(args, Namespace):
        # unpack Namespace into base dict obj
        args = vars(args)
    fields = Wav2Vec2CtcConfig.__dataclass_fields__
    # create dict for attributes of Wav2Vec2CtcConfig with vals taken from the same key in args, if they exist
    fields_dict = {}
    # this means Wav2Vec2CtcConfig obj fields will be overwritten with vals from args, otherwise they will be default
    for field in fields.keys():
        if field in args:
            fields_dict[field] = args[field]

    return fields_dict


def get_feature(filepath):
    def postprocess(feats, sample_rate):
        if feats.dim() == 2:
            # convert stereo to mono by averaging over the channel dimension
            feats = feats.mean(-1)

        assert feats.dim() == 1, feats.dim()

        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    feats = postprocess(feats, sample_rate)
    feats = feats.cuda()

    return feats


if __name__ == "__main__":
    model_path = "/path/to/checkpoint.pt"
    target_dict = Dictionary.load('/path/to/dict.ltr.txt')

    w2v = torch.load(model_path)

    args_dict = get_config_dict(w2v['cfg']['model'])
    w2v_config_obj = OmegaConf.merge(OmegaConf.structured(Wav2Vec2CtcConfig), args_dict)

    # Wav2VecCtc.build_model only reads the task's target_dictionary attribute,
    # so a simple Namespace stands in for the full task object
    dummy_target_dict = {'target_dictionary': target_dict.symbols}
    dummy_target_dict = Namespace(**dummy_target_dict)

    model = Wav2VecCtc.build_model(w2v_config_obj, dummy_target_dict)
    model.load_state_dict(w2v["model"], strict=True)
    model = model.cuda()
    model.eval()

    sample, input = dict(), dict()
    WAV_PATH = '/path/to/speech.wav'

    # define additional decoder args
    decoder_args = Namespace(**{'nbest': 1})
    generator = W2lViterbiDecoder(decoder_args, target_dict)

    feature = get_feature(WAV_PATH)
    input["source"] = feature.unsqueeze(0)

    # all-False padding mask of shape (1, T): the single utterance contains no padding
    padding_mask = torch.BoolTensor(input["source"].size(1)).fill_(False).unsqueeze(0)

    input["padding_mask"] = padding_mask
    sample["net_input"] = input

    models = list()
    models.append(model)

    with torch.no_grad():
        hypo = generator.generate(models, sample, prefix_tokens=None)

    hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())

    res = post_process(hyp_pieces, 'letter')
    print(res)
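
If more than one file needs to be transcribed, the loaded model and decoder can be reused in a loop. A minimal sketch reusing the functions and objects defined above (the directory path is a placeholder):

import os
import glob

for wav_path in sorted(glob.glob('/path/to/wavs/*.wav')):
    feature = get_feature(wav_path)
    net_input = {
        "source": feature.unsqueeze(0),
        # all-False padding mask, as above, since each utterance is decoded on its own
        "padding_mask": torch.BoolTensor(1, feature.size(0)).fill_(False),
    }
    with torch.no_grad():
        hypo = generator.generate([model], {"net_input": net_input}, prefix_tokens=None)
    hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
    print(os.path.basename(wav_path), post_process(hyp_pieces, 'letter'))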

Environment

  • fairseq version: 1.0.0a0+35cc605
  • flashlight version: 1.0.0
  • Hydra version: 1.0.7
  • PyTorch version: 1.12.1+cu113
  • OS: Ubuntu 20.04.3 LTS (GNU/Linux 5.4.0-122-generic x86_64)
  • How you installed fairseq: source
  • Build command: pip3 install packaging && pip3 install --editable ./
  • Python version: 3.8.5
  • CUDA/cuDNN version: 11.4
  • CUDA compilation tools: release 11.1, V11.1.105
  • GPU models and configuration: A6000