
Public Code for the paper MAE-AST: Masked Autoencoding Audio Spectrogram Transformer


AlanBaade/MAE-AST-Public


MAE-AST

This repository contains the code for the paper MAE-AST: Masked Autoencoding Audio Spectrogram Transformer. Pretrained checkpoints are available in the Pretrained Model Download section below.

This repository contains three folders: config, mae_ast, and s3prl. The config folder contains a default pretraining config for the MAE-AST. The mae_ast folder contains the main code for the model, which runs under fairseq; this includes a criterion, a task, data loading, and models. The s3prl folder provides the upstream model and configuration for fine-tuning the MAE-AST on SUPERB tasks with the S3PRL repository. This repository does not include fine-tuning code for AudioSet, LibriSpeech, and KS2; these are instead evaluated with the SSAST library with no settings changed.

Please email abaade@utexas.edu for questions.

Pretrained Model Download

Below are the two 12-layer models used in the overall results section of the paper, both pretrained with a 75% masking ratio. Clicking a checkpoint link may make the browser try to display the checkpoint as a text file; instead, download it with wget, or open the link in a new tab and save it.

Download   | Model         | Layers | Masking | AS    | ESC-50 | KS2   | KS1   | SID   | ER
Checkpoint | MAE-AST Patch | 12     | Chunked | 0.306 | 0.900  | 0.979 | 0.958 | -     | 0.598
Checkpoint | MAE-AST Frame | 12     | Random  | 0.230 | 0.889  | 0.980 | 0.973 | 0.633 | 0.621

Pre-Training

Pretraining is done with fairseq as follows.

Environment Setup

Run the following commands with conda to set up an environment for pretraining. This assumes that fairseq has been downloaded to the home directory, i.e. to ~/fairseq.

conda create -n fairseq_mae_ast python=3.9
conda activate fairseq_mae_ast
pip install soundfile
cd ~/fairseq
pip install -e ./
conda install tensorboardX
conda install av -c conda-forge
pip install sortedcontainers
pip install tensorboard
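Before launching a long pretraining job, it can help to confirm that the packages installed above are importable from the fairseq_mae_ast environment. The sketch below is a hypothetical helper, not part of this repository; it assumes the import names fairseq, soundfile, tensorboardX, av (PyAV), sortedcontainers, and tensorboard for the packages installed above.

```python
import importlib.util

# Import names of the packages installed above (assumed mapping:
# PyAV is imported as "av", tensorboardX as "tensorboardX").
REQUIRED_MODULES = ["fairseq", "soundfile", "tensorboardX", "av",
                    "sortedcontainers", "tensorboard"]


def missing_modules(modules=REQUIRED_MODULES):
    """Return the modules from `modules` that cannot be found in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All pretraining dependencies are importable.")
```

Running it inside the activated environment prints any package that still needs to be installed.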

Input files

The dataset code takes a directory that contains the files train.tsv, valid.tsv, and test.tsv, which list the train, valid, and test data respectively. Each of these is a tab-separated values file with a / on the first line, followed by one line per audio file: the audio file path, a tab, and the length of that audio file in frames (audio samples). For example, train.tsv starts with:

/
/path/to/AudioSet/unbalanced/6XUF56FlKvg.mkv     479232
/path/to/data/AudioSet/unbalanced/eJS_911G6ps.mkv     477696

and test.tsv starts with:

/
/path/to/LibriSpeech/data/test-other/3331/159609/3331-159609-0002.flac       225600
/path/to/LibriSpeech/data/test-other/3331/159609/3331-159609-0021.flac       165920

The dataset expects either mkv or flac files as input.
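A minimal sketch for writing and reading manifests in the format described above. It assumes flac lengths can be read with soundfile.info(path).frames (soundfile is installed above); mkv lengths are passed in precomputed, since decoding mkv audio (for example with PyAV) is outside this sketch. The function names are hypothetical and not part of the repository.

```python
from pathlib import Path

SUPPORTED_EXTENSIONS = {".mkv", ".flac"}


def flac_num_frames(path):
    """Length of a flac file in frames (samples per channel), via soundfile."""
    import soundfile  # imported lazily so the writer works without it

    return soundfile.info(str(path)).frames


def write_manifest(out_path, entries):
    """Write a '/' line, then one 'path<TAB>num_frames' line per entry.

    `entries` is an iterable of (path, num_frames) pairs; paths are written
    as given (the examples above use absolute paths).
    """
    with open(out_path, "w") as f:
        f.write("/\n")
        for path, num_frames in entries:
            if Path(path).suffix.lower() not in SUPPORTED_EXTENSIONS:
                raise ValueError(f"expected an mkv or flac file, got {path}")
            f.write(f"{path}\t{int(num_frames)}\n")


def read_manifest(path):
    """Parse a manifest back into a list of (path, num_frames) pairs."""
    with open(path) as f:
        first = f.readline().strip()
        if first != "/":
            raise ValueError(f"first line should be '/', got {first!r}")
        entries = []
        for line in f:
            if not line.strip():
                continue  # tolerate a trailing blank line
            file_path, num_frames = line.rstrip("\n").split("\t")
            entries.append((file_path, int(num_frames)))
    return entries
```

For a directory of flac files, the entries could be built as `[(str(p), flac_num_frames(p)) for p in sorted(root.rglob("*.flac"))]`, where `root` is a `Path` to the data.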

Environment Variables

Let MAE-AST-Public be the base directory of this repository.

Run the following to set up environment variables:

conda activate fairseq_mae_ast
cd ~/MAE-AST-Public
export HYDRA_FULL_ERROR=1
data_dir=/path/to/directory_with_train_valid_test_tsv_input_files
config_dir=/path/to/MAE-AST-Public/config/pretrain
user_dir=/path/to/MAE-AST-Public/mae_ast
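Since task.data is set to data_dir, a quick check that this directory holds the three manifest files can catch path mistakes before training starts. This is a hypothetical helper, not part of the repository.

```python
import os

EXPECTED_MANIFESTS = ("train.tsv", "valid.tsv", "test.tsv")


def missing_manifests(data_dir):
    """Return the expected manifest files that are absent from data_dir."""
    return [name for name in EXPECTED_MANIFESTS
            if not os.path.isfile(os.path.join(data_dir, name))]
```

Calling `missing_manifests("/path/to/directory_with_train_valid_test_tsv_input_files")` should return an empty list when the directory is set up correctly.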

Pretraining commands

The following run commands override the default pretraining configuration and contain the most important settings to change.

The configuration settings are defined at the top of mae_ast/models/mae_ast.py and mae_ast/tasks/mae_ast_pretraining.py. The main model logic (the model's forward pass) is in the middle of mae_ast/models/mae_ast.py.
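Because the overrides are plain key=value arguments to fairseq-hydra-train, they can also be assembled programmatically, for example to sweep the masking ratio. The sketch below builds an argument list suitable for subprocess.run, which bypasses the shell so values such as [0.0001] need no quoting. build_command is a hypothetical helper, and it only uses override keys that appear in the commands in this README.

```python
def build_command(config_dir, user_dir, data_dir, run_dir, overrides):
    """Return an argv list for fairseq-hydra-train with Hydra key=value overrides."""
    args = [
        "fairseq-hydra-train",
        "--config-dir", config_dir,
        "--config-name", "mae_ast",
        f"common.user_dir={user_dir}",
        f"task.data={data_dir}",
        "model._name=mae_ast",
        "criterion._name=mae_ast",
    ]
    # Each remaining setting becomes one key=value override.
    args += [f"{key}={value}" for key, value in overrides.items()]
    args.append(f"hydra.run.dir={run_dir}")
    return args


patch_overrides = {
    "model.encoder_layers": 12,
    "model.decoder_layers": 2,
    "model.random_mask_prob": 0.75,
    "task.mask_type": "chunk_mask",
    "optimization.lr": "[0.0001]",
}

cmd = build_command(
    "/path/to/MAE-AST-Public/config/pretrain",
    "/path/to/MAE-AST-Public/mae_ast",
    "/path/to/directory_with_train_valid_test_tsv_input_files",
    "/path/to/output_model_directory",
    patch_overrides,
)
# To launch training: import subprocess; subprocess.run(cmd, check=True)
```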

Patched, Chunked Masking (SSAST), 12 Layer Encoder, 75% masking ratio

Default Model Patch (12 Layer).

fairseq-hydra-train \
  --config-dir ${config_dir} --config-name mae_ast common.user_dir=${user_dir} task.data=${data_dir} model._name=mae_ast criterion._name=mae_ast \
  model.encoder_layers=12 model.decoder_layers=2 \
  model.random_mask_prob=0.75 task.mask_type="chunk_mask" \
  model.ast_kernel_size_chan=16 model.ast_kernel_size_time=16 model.ast_kernel_stride_chan=16 model.ast_kernel_stride_time=16 \
  criterion.classification_weight=1 criterion.reconstruction_weight=10 \
  distributed_training.distributed_world_size=1 distributed_training.nprocs_per_node=1 \
  common.log_interval=200 checkpoint.save_interval_updates=25000 \
  optimization.max_update=550000 dataset.max_tokens=8388608 'optimization.lr=[0.0001]' \
  hydra.run.dir=/path/to/output_model_directory
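Chunked masking, as in SSAST, masks blocks of neighboring patches instead of independent patches. The sketch below illustrates the idea on a (frequency × time) patch grid: it repeatedly masks a square chunk of random side length starting at a random patch until the target ratio is reached. The side lengths (3 to 5) and the function name are assumptions for illustration; this is not the task.mask_type="chunk_mask" implementation in mae_ast.

```python
import random


def chunk_mask(num_freq, num_time, mask_ratio, min_side=3, max_side=5, seed=None):
    """Return sorted masked patch indices, indexed row-major as f * num_time + t.

    Square chunks are added (clipped at the grid edges) until exactly
    round(mask_ratio * num_freq * num_time) patches are masked.
    """
    rng = random.Random(seed)
    total = num_freq * num_time
    target = round(mask_ratio * total)
    masked = set()
    while len(masked) < target:
        side = rng.randint(min_side, max_side)
        f0 = rng.randrange(num_freq)
        t0 = rng.randrange(num_time)
        for f in range(f0, min(f0 + side, num_freq)):
            for t in range(t0, min(t0 + side, num_time)):
                if len(masked) < target:
                    masked.add(f * num_time + t)
    return sorted(masked)
```

For example, a hypothetical 8 × 64 grid (128 mel channels and 1024 time frames cut into 16×16 patches) has 512 patches, so `chunk_mask(8, 64, 0.75)` masks 384 of them in contiguous blocks.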

Frame, Random Masking, 12 Layer Encoder, 75% masking ratio

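Default Model Frame (12 Layer). The kernel sizes and strides determine whether the model is patch-based or frame-based: the patch model above uses a 16×16 kernel with stride 16 in both dimensions, while this frame model uses a kernel and stride of 128 along the channel axis and 2 along the time axis.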

fairseq-hydra-train \
  --config-dir ${config_dir} --config-name mae_ast common.user_dir=${user_dir} task.data=${data_dir} model._name=mae_ast criterion._name=mae_ast \
  model.encoder_layers=12 model.decoder_layers=2 \
  model.random_mask_prob=0.75 task.mask_type="random_mask" \
  model.ast_kernel_size_chan=128 model.ast_kernel_size_time=2 model.ast_kernel_stride_chan=128 model.ast_kernel_stride_time=2 \
  criterion.classification_weight=1 criterion.reconstruction_weight=10 \
  distributed_training.distributed_world_size=1 distributed_training.nprocs_per_node=1 \
  common.log_interval=200 checkpoint.save_interval_updates=25000 \
  optimization.max_update=550000 dataset.max_tokens=8388608 'optimization.lr=[0.0001]' \
  hydra.run.dir=/path/to/output_model_directory
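To see how the kernel and stride settings change the number of tokens, the usual convolution output-size formula, floor((input - kernel) / stride) + 1, can be applied to each axis, assuming a Conv2d-style patch embedding as in AST. The sketch assumes 128 mel channels and a hypothetical 1024 time frames; random masking then picks a uniformly random subset of tokens. The helper names are illustrative only.

```python
import random


def num_tokens(n_chan, n_time, k_chan, k_time, s_chan, s_time):
    """Number of tokens produced by a conv-style patch embedding."""
    out_chan = (n_chan - k_chan) // s_chan + 1
    out_time = (n_time - k_time) // s_time + 1
    return out_chan * out_time


def random_mask(n_tokens, mask_ratio, seed=None):
    """Pick round(mask_ratio * n_tokens) distinct token indices uniformly at random."""
    rng = random.Random(seed)
    return sorted(rng.sample(range(n_tokens), round(mask_ratio * n_tokens)))


patch = num_tokens(128, 1024, 16, 16, 16, 16)  # 8 x 64 patch tokens
frame = num_tokens(128, 1024, 128, 2, 128, 2)  # 1 x 512 frame tokens
masked = random_mask(frame, 0.75, seed=0)       # 75% of the frame tokens
```

Under these assumptions both configurations yield 512 tokens; they differ in shape, since each frame token spans all 128 channels over 2 time steps while each patch token covers a 16×16 region.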

Frame, Chunked Masking (Wav2Vec2), 12 Layer Encoder, 75% masking ratio

Because spans in Wav2Vec2-style masking (specified by task.mask_type="retain_spans") can overlap, the random mask probability is set to 1.45 to produce an average masking ratio of 75%. Set the random mask probability to 0.74 for an average masking ratio of 50%. For all other mask types, the random mask probability directly equals the fraction of tokens masked.

fairseq-hydra-train \
  --config-dir ${config_dir} --config-name mae_ast common.user_dir=${user_dir} task.data=${data_dir} model._name=mae_ast criterion._name=mae_ast \
  model.encoder_layers=12 model.decoder_layers=2 \
  model.random_mask_prob=1.45 task.mask_type="retain_spans" \
  model.ast_kernel_size_chan=128 model.ast_kernel_size_time=2 model.ast_kernel_stride_chan=128 model.ast_kernel_stride_time=2 \
  criterion.classification_weight=1 criterion.reconstruction_weight=10 \
  distributed_training.distributed_world_size=1 distributed_training.nprocs_per_node=1 \
  common.log_interval=200 checkpoint.save_interval_updates=25000 \
  optimization.max_update=550000 dataset.max_tokens=8388608 'optimization.lr=[0.0001]' \
  hydra.run.dir=/path/to/output_model_directory
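Overlap is why the nominal probability exceeds the realized masking ratio. The simulation below follows the general wav2vec 2.0 scheme: about mask_prob * length / span_length span starts are sampled without replacement, each start masks span_length consecutive frames, and frames covered by several spans are counted once. The span length of 10, the sequence length, and the function name are assumptions, so the printed ratios approximate, but need not exactly match, the 75% and 50% figures above.

```python
import random


def span_mask_ratio(length, mask_prob, span_length=10, trials=200, seed=0):
    """Average fraction of frames covered by overlapping fixed-length spans."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        # Number of spans, with probabilistic rounding of the fractional part.
        num_spans = int(mask_prob * length / span_length + rng.random())
        num_spans = min(num_spans, length - span_length)
        starts = rng.sample(range(length - span_length), num_spans)
        masked = set()
        for s in starts:
            masked.update(range(s, s + span_length))  # overlaps counted once
        total += len(masked) / length
    return total / trials


for p in (0.74, 1.45):
    print(f"mask_prob={p}: about {span_mask_ratio(500, p):.0%} of frames masked")
```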

Fine-Tuning

The s3prl directory contains an example of fine-tuning the MAE-AST on SUPERB, plus a README with specific fine-tuning settings. s3prl/mae_ast/hubconf.py takes a checkpoint generated during pretraining and uses it for downstream tasks.
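With checkpoint.save_interval_updates=25000, fairseq writes intermediate checkpoints named checkpoint_<epoch>_<updates>.pt next to checkpoint_last.pt. The hypothetical helper below picks the one with the most updates from a checkpoint directory. Where that directory lives (typically a checkpoints folder inside the Hydra run directory) is an assumption; check your run's output.

```python
import re
from pathlib import Path

# fairseq names update-interval checkpoints checkpoint_<epoch>_<updates>.pt
CKPT_PATTERN = re.compile(r"checkpoint_(\d+)_(\d+)\.pt")


def latest_update_checkpoint(ckpt_dir):
    """Return the checkpoint_<epoch>_<updates>.pt path with the most updates, or None."""
    best, best_updates = None, -1
    for path in Path(ckpt_dir).glob("checkpoint_*_*.pt"):
        match = CKPT_PATTERN.fullmatch(path.name)
        if match and int(match.group(2)) > best_updates:
            best, best_updates = path, int(match.group(2))
    return best
```

The returned path is the kind of pretraining checkpoint that hubconf.py takes as input.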
