[ASR] wav2vec2_en, test=asr (#2637)
* wav2vec2_en, test=asr

* wav2vec2_en, test=asr

* wav2vec2_en, test=asr
Zth9730 authored Nov 10, 2022
1 parent 0727984 commit 8d34943
Showing 4 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/source/released_model.md
@@ -22,7 +22,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
Model | Pre-Train Method | Pre-Train Data | Finetune Data | Size | Descriptions | CER | WER | Example Link |
:-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----: |
[Wav2vec2-large-960h-lv60-self Model](https://paddlespeech.bj.bcebos.com/wav2vec/wav2vec2-large-960h-lv60-self.pdparams) | wav2vec2 | Librispeech and LV-60k Dataset (5.3w h) | - | 1.18 GB |Pre-trained Wav2vec2.0 Model | - | - | - |
[Wav2vec2ASR-large-960h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.0.model.tar.gz) | wav2vec2 | Librispeech and LV-60k Dataset (5.3w h) | Librispeech (960 h) | 1.18 GB |Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | - | 0.0189 | [Wav2vecASR Librispeech ASR3](../../examples/librispeech/asr3) |
[Wav2vec2ASR-large-960h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz) | wav2vec2 | Librispeech and LV-60k Dataset (5.3w h) | Librispeech (960 h) | 718 MB |Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | - | 0.0189 | [Wav2vecASR Librispeech ASR3](../../examples/librispeech/asr3) |
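
The WER in these rows is reported with CTC greedy-search decoding. As a quick reminder of what that decoding step does, here is a minimal, framework-free sketch of greedy CTC decoding (an illustration only, not PaddleSpeech's implementation):

```python
def ctc_greedy_decode(log_probs, blank_id=0):
    """Greedy CTC decoding: argmax per frame, collapse repeats, drop blanks.

    log_probs: per-frame score vectors (frames x vocab), lists or arrays.
    Returns the decoded token-id sequence.
    """
    best_path = [max(range(len(frame)), key=frame.__getitem__) for frame in log_probs]
    decoded, prev = [], None
    for tok in best_path:
        if tok != prev and tok != blank_id:
            decoded.append(tok)
        prev = tok
    return decoded

# Toy example: 5 frames over a 4-symbol vocabulary (0 = blank).
frames = [[0.6, 0.2, 0.1, 0.1],   # blank
          [0.1, 0.7, 0.1, 0.1],   # token 1
          [0.1, 0.7, 0.1, 0.1],   # token 1 (repeat, collapsed)
          [0.6, 0.2, 0.1, 0.1],   # blank
          [0.1, 0.1, 0.1, 0.7]]   # token 3
print(ctc_greedy_decode(frames))  # [1, 3]
```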

### Language Model based on NGram
Language Model | Training Data | Token-based | Size | Descriptions
10 changes: 6 additions & 4 deletions examples/librispeech/asr3/conf/wav2vec2ASR.yaml
@@ -70,7 +70,6 @@ train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean


###########################################
# Dataloader #
###########################################
@@ -95,6 +94,12 @@ dist_sampler: True
shortest_first: True
return_lens_rate: True

############################################
# Data Augmentation #
############################################
audio_augment:   # for raw audio
  sample_rate: 16000
  speeds: [95, 100, 105]

###########################################
# Training #
@@ -115,6 +120,3 @@ log_interval: 1
checkpoint:
  kbest_n: 50
  latest_n: 5
augment: True


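
For orientation, a sketch of how a nested `audio_augment` section like the one added above ends up as keyword arguments for the augmentation module, replacing the old bare `augment: True` flag (the PyYAML parsing here is an assumption for illustration; the values mirror the block above):

```python
import yaml  # PyYAML; assumed available for this illustration

# Hypothetical snippet mirroring the section added to wav2vec2ASR.yaml above.
config_text = """
audio_augment:        # for raw audio
  sample_rate: 16000
  speeds: [95, 100, 105]
"""

config = yaml.safe_load(config_text)
# The nested mapping becomes a plain dict, ready to be splatted into the
# augmentation constructor, e.g. TimeDomainSpecAugment(**config["audio_augment"]).
print(config["audio_augment"])  # {'sample_rate': 16000, 'speeds': [95, 100, 105]}
```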
7 changes: 5 additions & 2 deletions paddlespeech/s2t/exps/wav2vec2/model.py
@@ -71,7 +71,8 @@ def train_batch(self, batch_index, batch, msg):
        wavs_lens_rate = wavs_lens / wav.shape[1]
        target_lens_rate = target_lens / target.shape[1]
        wav = wav[:, :, 0]
        wav = self.speech_augmentation(wav, wavs_lens_rate)
        if hasattr(train_conf, 'speech_augment'):
            wav = self.speech_augmentation(wav, wavs_lens_rate)
        loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
        # loss div by `batch_size * accum_grad`
        loss /= train_conf.accum_grad
@@ -277,7 +278,9 @@ def setup_model(self):
        logger.info("Setup model!")

        # setup speech augmentation for wav2vec2
        self.speech_augmentation = TimeDomainSpecAugment()
        if hasattr(config, 'audio_augment') and self.train:
            self.speech_augmentation = TimeDomainSpecAugment(
                **config.audio_augment)

        if not self.train:
            return
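
Both hunks in this file gate augmentation on the config: the module is only constructed when the config carries an `audio_augment` section and the exp is in training mode, and the per-batch call in `train_batch` is likewise wrapped in a `hasattr` check. A toy sketch of that pattern (names below are illustrative, not the PaddleSpeech API):

```python
from types import SimpleNamespace

def maybe_build_augmentation(config, is_train, factory):
    """Build the augmentation module only when configured and training."""
    if is_train and hasattr(config, "audio_augment"):
        return factory(**config.audio_augment)
    return None

# Toy usage: a namespace standing in for the parsed training config.
cfg = SimpleNamespace(audio_augment={"sample_rate": 16000, "speeds": [95, 100, 105]})
aug = maybe_build_augmentation(cfg, is_train=True, factory=lambda **kw: kw)
print(aug)  # {'sample_rate': 16000, 'speeds': [95, 100, 105]}
```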
@@ -641,14 +641,11 @@ def forward(self, waveforms, lengths):

class TimeDomainSpecAugment(nn.Layer):
"""A time-domain approximation of the SpecAugment algorithm.
This augmentation module implements three augmentations in
the time-domain.
1. Drop chunks of the audio (zero amplitude or white noise)
2. Drop frequency bands (with band-drop filters)
3. Speed perturbation (via resampling to slightly different rate)
Arguments
---------
perturb_prob : float from 0 to 1
@@ -677,7 +674,6 @@ class TimeDomainSpecAugment(nn.Layer):
drop_chunk_noise_factor : float
The noise factor used to scale the white noise inserted, relative to
the average amplitude of the utterance. Default 0 (no noise inserted).
Example
-------
>>> inputs = paddle.randn([10, 16000])
@@ -718,7 +714,6 @@ def __init__(

def forward(self, waveforms, lengths):
"""Returns the distorted waveforms.
Arguments
---------
waveforms : tensor
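
For context, a hedged sketch of how this module is called on a batch of raw waveforms, based on the `forward(self, waveforms, lengths)` signature and the docstring example above (constructor kwargs mirror the `audio_augment` config; the import is omitted because the file path is not shown in this excerpt):

```python
import paddle

# Assumes TimeDomainSpecAugment has already been imported from the module
# changed above (the file path is not shown in this excerpt).
augmenter = TimeDomainSpecAugment(sample_rate=16000, speeds=[95, 100, 105])

inputs = paddle.randn([10, 16000])      # batch of 10 one-second, 16 kHz waveforms
lengths = paddle.ones([10])             # relative lengths: 1.0 means "full length"
augmented = augmenter(inputs, lengths)  # chunk-dropped / freq-dropped / speed-perturbed audio
```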
