From 200599b851bb1c8449861d3a14af90025c35b1e5 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 8 Aug 2023 01:45:30 +0800 Subject: [PATCH 1/7] Support training variance model with DS files --- configs/variance.yaml | 1 + preprocessing/variance_binarizer.py | 100 +++++++++++++++++++++++++--- 2 files changed, 92 insertions(+), 9 deletions(-) diff --git a/configs/variance.yaml b/configs/variance.yaml index b23c2a5a..d437729d 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -23,6 +23,7 @@ midi_smooth_width: 0.06 # in seconds binarization_args: shuffle: true num_workers: 0 + prefer_ds: false raw_data_dir: 'data/opencpop_variance/raw' binary_data_dir: 'data/opencpop_variance/binary' diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index c9009ac2..e484f392 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -1,4 +1,5 @@ import csv +import json import os import pathlib @@ -19,6 +20,8 @@ get_breathiness_pyworld ) from utils.hparams import hparams +from utils.infer_utils import resample_align_curve +from utils.pitch_utils import interp_f0 from utils.plot import distribution_to_figure os.environ["OMP_NUM_THREADS"] = "1" @@ -52,31 +55,61 @@ def __init__(self): predict_breathiness = hparams['predict_breathiness'] self.predict_variances = predict_energy or predict_breathiness self.lr = LengthRegulator().to(self.device) + self.prefer_ds = self.binarization_args['prefer_ds'] + self.cached_ds = {} + + def load_attr_from_ds(self, ds_id, name, attr, idx=0): + item_name = f'{ds_id}:{name}' + if item_name not in self.cached_ds: + ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}.ds' + if not ds_path.exists(): + return None + with open(ds_path, 'r', encoding='utf8') as f: + ds = json.load(f) + if not isinstance(ds, list): + ds = [ds] + self.cached_ds[item_name] = ds + ds = ds[idx] + else: + ds = self.cached_ds[item_name][idx] + return ds.get(attr) def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): meta_data_dict = {} + for utterance_label in csv.DictReader( open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') ): item_name = utterance_label['name'] + item_idx = int(item_name.rsplit('#', maxsplit=1)[-1]) if '#' in item_name else 0 + + def fallback(attr): + if self.prefer_ds: + value = self.load_attr_from_ds(ds_id, item_name, attr, item_idx) + else: + value = None + if value is None: + return utterance_label[attr] + temp_dict = { + 'ds_idx': item_idx, 'spk_id': spk_id, 'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav'), - 'ph_seq': utterance_label['ph_seq'].split(), - 'ph_dur': [float(x) for x in utterance_label['ph_dur'].split()] + 'ph_seq': fallback('ph_seq').split(), + 'ph_dur': [float(x) for x in fallback('ph_dur').split()] } assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \ f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.' if hparams['predict_dur']: - temp_dict['ph_num'] = [int(x) for x in utterance_label['ph_num'].split()] + temp_dict['ph_num'] = [int(x) for x in fallback('ph_num').split()] assert len(temp_dict['ph_seq']) == sum(temp_dict['ph_num']), \ f'Sum of ph_num does not equal length of ph_seq in \'{item_name}\'.' if hparams['predict_pitch']: - temp_dict['note_seq'] = utterance_label['note_seq'].split() - temp_dict['note_dur'] = [float(x) for x in utterance_label['note_dur'].split()] + temp_dict['note_seq'] = fallback('note_seq').split() + temp_dict['note_dur'] = [float(x) for x in fallback('note_dur').split()] assert len(temp_dict['note_seq']) == len(temp_dict['note_dur']), \ f'Lengths of note_seq and note_dur mismatch in \'{item_name}\'.' assert any([note != 'rest' for note in temp_dict['note_seq']]), \ @@ -129,6 +162,9 @@ def check_coverage(self): @torch.no_grad() def process_item(self, item_name, meta_data, binarization_args): + ds_id, name = item_name.split(':', maxsplit=1) + ds_id = int(ds_id) + ds_seg_idx = meta_data['ds_idx'] seconds = sum(meta_data['ph_dur']) length = round(seconds / self.timestep) T_ph = len(meta_data['ph_seq']) @@ -154,11 +190,31 @@ def process_item(self, item_name, meta_data, binarization_args): processed_input['mel2ph'] = mel2ph.cpu().numpy() # Below: extract actual f0, convert to pitch and calculate delta pitch - waveform, _ = librosa.load(meta_data['wav_fn'], sr=hparams['audio_sample_rate'], mono=True) + if pathlib.Path(meta_data['wav_fn']).exists(): + waveform, _ = librosa.load(meta_data['wav_fn'], sr=hparams['audio_sample_rate'], mono=True) + elif not self.prefer_ds: + raise FileNotFoundError(meta_data['wav_fn']) + else: + waveform = None + + global pitch_extractor if pitch_extractor is None: pitch_extractor = initialize_pe() - f0, uv = pitch_extractor.get_pitch(waveform, length, hparams, interp_uv=True) + f0 = uv = None + if self.prefer_ds: + f0_seq = self.load_attr_from_ds(ds_id, name, 'f0_seq', idx=ds_seg_idx) + if f0_seq is not None: + f0 = resample_align_curve( + np.array(f0_seq.split(), np.float32), + original_timestep=float(self.load_attr_from_ds(ds_id, name, 'f0_timestep', idx=ds_seg_idx)), + target_timestep=self.timestep, + align_length=length + ) + uv = f0 == 0 + f0 = interp_f0(f0, uv) + if f0 is None: + f0, uv = pitch_extractor.get_pitch(waveform, length, hparams, interp_uv=True) if uv.all(): # All unvoiced print(f'Skipped \'{item_name}\': empty gt f0') return None @@ -208,7 +264,20 @@ def process_item(self, item_name, meta_data, binarization_args): # Below: extract energy if hparams['predict_energy']: - energy = get_energy_librosa(waveform, length, hparams).astype(np.float32) + energy = None + if self.prefer_ds: + energy_seq = self.load_attr_from_ds(ds_id, name, 'energy', idx=ds_seg_idx) + if energy_seq is not None: + energy = resample_align_curve( + np.array(energy_seq.split(), np.float32), + original_timestep=float(self.load_attr_from_ds( + ds_id, name, 'energy_timestep', idx=ds_seg_idx + )), + target_timestep=self.timestep, + align_length=length + ) + if energy is None: + energy = get_energy_librosa(waveform, length, hparams).astype(np.float32) global energy_smooth if energy_smooth is None: @@ -221,7 +290,20 @@ def process_item(self, item_name, meta_data, binarization_args): # Below: extract breathiness if hparams['predict_breathiness']: - breathiness = get_breathiness_pyworld(waveform, f0 * ~uv, length, hparams).astype(np.float32) + breathiness = None + if self.prefer_ds: + breathiness_seq = self.load_attr_from_ds(ds_id, name, 'breathiness', idx=ds_seg_idx) + if breathiness_seq is not None: + breathiness = resample_align_curve( + np.array(breathiness_seq.split(), np.float32), + original_timestep=float(self.load_attr_from_ds( + ds_id, name, 'breathiness_timestep', idx=ds_seg_idx + )), + target_timestep=self.timestep, + align_length=length + ) + if breathiness is None: + breathiness = get_breathiness_pyworld(waveform, f0 * ~uv, length, hparams).astype(np.float32) global breathiness_smooth if breathiness_smooth is None: From 3efa01966d6443071c34df1e626abe18d9871052 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Wed, 9 Aug 2023 17:24:03 +0800 Subject: [PATCH 2/7] Prefer full name matching --- preprocessing/variance_binarizer.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index e484f392..58013bcf 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -57,21 +57,30 @@ def __init__(self): self.lr = LengthRegulator().to(self.device) self.prefer_ds = self.binarization_args['prefer_ds'] self.cached_ds = {} + self.ds_idx_sep = '#' def load_attr_from_ds(self, ds_id, name, attr, idx=0): item_name = f'{ds_id}:{name}' - if item_name not in self.cached_ds: - ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}.ds' + item_name_with_idx = f'{item_name}{self.ds_idx_sep}{idx}' + if item_name_with_idx in self.cached_ds: + ds = self.cached_ds[item_name_with_idx][0] + elif item_name in self.cached_ds: + ds = self.cached_ds[item_name][idx] + else: + ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}{self.ds_idx_sep}{idx}.ds' + if ds_path.exists(): + cache_key = item_name_with_idx + else: + ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}.ds' + cache_key = item_name if not ds_path.exists(): return None with open(ds_path, 'r', encoding='utf8') as f: ds = json.load(f) if not isinstance(ds, list): ds = [ds] - self.cached_ds[item_name] = ds + self.cached_ds[cache_key] = ds ds = ds[idx] - else: - ds = self.cached_ds[item_name][idx] return ds.get(attr) def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): @@ -81,7 +90,7 @@ def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') ): item_name = utterance_label['name'] - item_idx = int(item_name.rsplit('#', maxsplit=1)[-1]) if '#' in item_name else 0 + item_idx = int(item_name.rsplit(self.ds_idx_sep, maxsplit=1)[-1]) if self.ds_idx_sep in item_name else 0 def fallback(attr): if self.prefer_ds: From a0d1f322301763e74e37f096ae5304fb4c72ca33 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Wed, 9 Aug 2023 17:36:12 +0800 Subject: [PATCH 3/7] Fix attribute error --- basics/base_binarizer.py | 24 ++++++++++++++---------- preprocessing/variance_binarizer.py | 8 ++++---- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/basics/base_binarizer.py b/basics/base_binarizer.py index b7eeca25..c39fa357 100644 --- a/basics/base_binarizer.py +++ b/basics/base_binarizer.py @@ -63,19 +63,13 @@ def __init__(self, data_dir=None, data_attrs=None): self.build_spk_map() self.items = {} + self.item_names: list = None + self._train_item_names: list = None + self._valid_item_names: list = None + self.phone_encoder = TokenTextEncoder(vocab_list=build_phoneme_list()) self.timestep = hparams['hop_size'] / hparams['audio_sample_rate'] - # load each dataset - for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs): - self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id) - self.item_names = sorted(list(self.items.keys())) - self._train_item_names, self._valid_item_names = self.split_train_valid_set() - - if self.binarization_args['shuffle']: - random.seed(hparams['seed']) - random.shuffle(self.item_names) - def build_spk_map(self): assert isinstance(self.speakers, list), 'Speakers must be a list' assert len(self.speakers) == len(self.raw_data_dirs), \ @@ -171,6 +165,16 @@ def meta_data_iterator(self, prefix): yield item_name, meta_data def process(self): + # load each dataset + for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs): + self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id) + self.item_names = sorted(list(self.items.keys())) + self._train_item_names, self._valid_item_names = self.split_train_valid_set() + + if self.binarization_args['shuffle']: + random.seed(hparams['seed']) + random.shuffle(self.item_names) + self.binary_data_dir.mkdir(parents=True, exist_ok=True) # Copy spk_map and dictionary to binary data dir diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index 58013bcf..596ed3f5 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -38,6 +38,7 @@ 'energy', # frame-level RMS (dB), float32[T_s,] 'breathiness', # frame-level RMS of aperiodic parts (dB), float32[T_s,] ] +DS_INDEX_SEP = '#' # These operators are used as global variables due to a PyTorch shared memory bug on Windows platforms. # See https://github.com/pytorch/pytorch/issues/100358 @@ -57,17 +58,16 @@ def __init__(self): self.lr = LengthRegulator().to(self.device) self.prefer_ds = self.binarization_args['prefer_ds'] self.cached_ds = {} - self.ds_idx_sep = '#' def load_attr_from_ds(self, ds_id, name, attr, idx=0): item_name = f'{ds_id}:{name}' - item_name_with_idx = f'{item_name}{self.ds_idx_sep}{idx}' + item_name_with_idx = f'{item_name}{DS_INDEX_SEP}{idx}' if item_name_with_idx in self.cached_ds: ds = self.cached_ds[item_name_with_idx][0] elif item_name in self.cached_ds: ds = self.cached_ds[item_name][idx] else: - ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}{self.ds_idx_sep}{idx}.ds' + ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}{DS_INDEX_SEP}{idx}.ds' if ds_path.exists(): cache_key = item_name_with_idx else: @@ -90,7 +90,7 @@ def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') ): item_name = utterance_label['name'] - item_idx = int(item_name.rsplit(self.ds_idx_sep, maxsplit=1)[-1]) if self.ds_idx_sep in item_name else 0 + item_idx = int(item_name.rsplit(DS_INDEX_SEP, maxsplit=1)[-1]) if DS_INDEX_SEP in item_name else 0 def fallback(attr): if self.prefer_ds: From c127c10599a9613c31e7f6d0fdcdfc5fbad9ed2b Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Wed, 9 Aug 2023 17:45:00 +0800 Subject: [PATCH 4/7] Fix NoneType error --- preprocessing/variance_binarizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index 596ed3f5..5d81493a 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -97,8 +97,9 @@ def fallback(attr): value = self.load_attr_from_ds(ds_id, item_name, attr, item_idx) else: value = None - if value is None: - return utterance_label[attr] + if value is not None: + return value + return utterance_label[attr] temp_dict = { 'ds_idx': item_idx, @@ -206,7 +207,6 @@ def process_item(self, item_name, meta_data, binarization_args): else: waveform = None - global pitch_extractor if pitch_extractor is None: pitch_extractor = initialize_pe() @@ -221,7 +221,7 @@ def process_item(self, item_name, meta_data, binarization_args): align_length=length ) uv = f0 == 0 - f0 = interp_f0(f0, uv) + f0, _ = interp_f0(f0, uv) if f0 is None: f0, uv = pitch_extractor.get_pitch(waveform, length, hparams, interp_uv=True) if uv.all(): # All unvoiced From 5a591de274b7788d4ba4bf37b7dab5ecd90b6fca Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Wed, 9 Aug 2023 20:16:31 +0800 Subject: [PATCH 5/7] Add error message if attribute is missing --- preprocessing/variance_binarizer.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index 5d81493a..e968525d 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -71,7 +71,7 @@ def load_attr_from_ds(self, ds_id, name, attr, idx=0): if ds_path.exists(): cache_key = item_name_with_idx else: - ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}.ds' + ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}.ds' cache_key = item_name if not ds_path.exists(): return None @@ -89,37 +89,40 @@ def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): for utterance_label in csv.DictReader( open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') ): + utterance_label: dict item_name = utterance_label['name'] item_idx = int(item_name.rsplit(DS_INDEX_SEP, maxsplit=1)[-1]) if DS_INDEX_SEP in item_name else 0 - def fallback(attr): + def require(attr): if self.prefer_ds: value = self.load_attr_from_ds(ds_id, item_name, attr, item_idx) else: value = None - if value is not None: - return value - return utterance_label[attr] + if value is None: + value = utterance_label.get(attr) + if value is None: + raise ValueError(f'Missing required attribute {attr} of item \'{item_name}\'.') + return value temp_dict = { 'ds_idx': item_idx, 'spk_id': spk_id, 'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav'), - 'ph_seq': fallback('ph_seq').split(), - 'ph_dur': [float(x) for x in fallback('ph_dur').split()] + 'ph_seq': require('ph_seq').split(), + 'ph_dur': [float(x) for x in require('ph_dur').split()] } assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \ f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.' if hparams['predict_dur']: - temp_dict['ph_num'] = [int(x) for x in fallback('ph_num').split()] + temp_dict['ph_num'] = [int(x) for x in require('ph_num').split()] assert len(temp_dict['ph_seq']) == sum(temp_dict['ph_num']), \ f'Sum of ph_num does not equal length of ph_seq in \'{item_name}\'.' if hparams['predict_pitch']: - temp_dict['note_seq'] = fallback('note_seq').split() - temp_dict['note_dur'] = [float(x) for x in fallback('note_dur').split()] + temp_dict['note_seq'] = require('note_seq').split() + temp_dict['note_dur'] = [float(x) for x in require('note_dur').split()] assert len(temp_dict['note_seq']) == len(temp_dict['note_dur']), \ f'Lengths of note_seq and note_dur mismatch in \'{item_name}\'.' assert any([note != 'rest' for note in temp_dict['note_seq']]), \ From ffe85463d32ed0144d8672644f0a8ed1ef69f507 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Fri, 11 Aug 2023 11:24:05 +0800 Subject: [PATCH 6/7] Trim ds segment index suffix --- preprocessing/variance_binarizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index e968525d..73fe71d4 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -176,6 +176,7 @@ def check_coverage(self): @torch.no_grad() def process_item(self, item_name, meta_data, binarization_args): ds_id, name = item_name.split(':', maxsplit=1) + name = name.rsplit(DS_INDEX_SEP, maxsplit=1)[0] ds_id = int(ds_id) ds_seg_idx = meta_data['ds_idx'] seconds = sum(meta_data['ph_dur']) From fa0b3da8b8bfe3c5e9d4e74bb9ec1271836a7400 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Fri, 11 Aug 2023 11:32:29 +0800 Subject: [PATCH 7/7] Skip parameter smoothing if loading from DS file --- preprocessing/variance_binarizer.py | 34 +++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index 73fe71d4..f868ff36 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -278,6 +278,7 @@ def process_item(self, item_name, meta_data, binarization_args): # Below: extract energy if hparams['predict_energy']: energy = None + energy_from_wav = False if self.prefer_ds: energy_seq = self.load_attr_from_ds(ds_id, name, 'energy', idx=ds_seg_idx) if energy_seq is not None: @@ -291,19 +292,22 @@ def process_item(self, item_name, meta_data, binarization_args): ) if energy is None: energy = get_energy_librosa(waveform, length, hparams).astype(np.float32) + energy_from_wav = True - global energy_smooth - if energy_smooth is None: - energy_smooth = SinusoidalSmoothingConv1d( - round(hparams['energy_smooth_width'] / self.timestep) - ).eval().to(self.device) - energy = energy_smooth(torch.from_numpy(energy).to(self.device)[None])[0] + if energy_from_wav: + global energy_smooth + if energy_smooth is None: + energy_smooth = SinusoidalSmoothingConv1d( + round(hparams['energy_smooth_width'] / self.timestep) + ).eval().to(self.device) + energy = energy_smooth(torch.from_numpy(energy).to(self.device)[None])[0].cpu().numpy() - processed_input['energy'] = energy.cpu().numpy() + processed_input['energy'] = energy # Below: extract breathiness if hparams['predict_breathiness']: breathiness = None + breathiness_from_wav = False if self.prefer_ds: breathiness_seq = self.load_attr_from_ds(ds_id, name, 'breathiness', idx=ds_seg_idx) if breathiness_seq is not None: @@ -317,15 +321,17 @@ def process_item(self, item_name, meta_data, binarization_args): ) if breathiness is None: breathiness = get_breathiness_pyworld(waveform, f0 * ~uv, length, hparams).astype(np.float32) + breathiness_from_wav = True - global breathiness_smooth - if breathiness_smooth is None: - breathiness_smooth = SinusoidalSmoothingConv1d( - round(hparams['breathiness_smooth_width'] / self.timestep) - ).eval().to(self.device) - breathiness = breathiness_smooth(torch.from_numpy(breathiness).to(self.device)[None])[0] + if breathiness_from_wav: + global breathiness_smooth + if breathiness_smooth is None: + breathiness_smooth = SinusoidalSmoothingConv1d( + round(hparams['breathiness_smooth_width'] / self.timestep) + ).eval().to(self.device) + breathiness = breathiness_smooth(torch.from_numpy(breathiness).to(self.device)[None])[0].cpu().numpy() - processed_input['breathiness'] = breathiness.cpu().numpy() + processed_input['breathiness'] = breathiness return processed_input