Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support training variance models from DS files #132

Merged
merged 7 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions basics/base_binarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,13 @@ def __init__(self, data_dir=None, data_attrs=None):
self.build_spk_map()

self.items = {}
self.item_names: list = None
self._train_item_names: list = None
self._valid_item_names: list = None

self.phone_encoder = TokenTextEncoder(vocab_list=build_phoneme_list())
self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']

# load each dataset
for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs):
self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id)
self.item_names = sorted(list(self.items.keys()))
self._train_item_names, self._valid_item_names = self.split_train_valid_set()

if self.binarization_args['shuffle']:
random.seed(hparams['seed'])
random.shuffle(self.item_names)

def build_spk_map(self):
assert isinstance(self.speakers, list), 'Speakers must be a list'
assert len(self.speakers) == len(self.raw_data_dirs), \
Expand Down Expand Up @@ -171,6 +165,16 @@ def meta_data_iterator(self, prefix):
yield item_name, meta_data

def process(self):
# load each dataset
for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs):
self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id)
self.item_names = sorted(list(self.items.keys()))
self._train_item_names, self._valid_item_names = self.split_train_valid_set()

if self.binarization_args['shuffle']:
random.seed(hparams['seed'])
random.shuffle(self.item_names)

self.binary_data_dir.mkdir(parents=True, exist_ok=True)

# Copy spk_map and dictionary to binary data dir
Expand Down
1 change: 1 addition & 0 deletions configs/variance.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ midi_smooth_width: 0.06 # in seconds
binarization_args:
shuffle: true
num_workers: 0
prefer_ds: false

raw_data_dir: 'data/opencpop_variance/raw'
binary_data_dir: 'data/opencpop_variance/binary'
Expand Down
155 changes: 128 additions & 27 deletions preprocessing/variance_binarizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import json
import os
import pathlib

Expand All @@ -19,6 +20,8 @@
get_breathiness_pyworld
)
from utils.hparams import hparams
from utils.infer_utils import resample_align_curve
from utils.pitch_utils import interp_f0
from utils.plot import distribution_to_figure

os.environ["OMP_NUM_THREADS"] = "1"
Expand All @@ -35,6 +38,7 @@
'energy', # frame-level RMS (dB), float32[T_s,]
'breathiness', # frame-level RMS of aperiodic parts (dB), float32[T_s,]
]
DS_INDEX_SEP = '#'

# These operators are used as global variables due to a PyTorch shared memory bug on Windows platforms.
# See https://github.com/pytorch/pytorch/issues/100358
Expand All @@ -52,31 +56,73 @@ def __init__(self):
predict_breathiness = hparams['predict_breathiness']
self.predict_variances = predict_energy or predict_breathiness
self.lr = LengthRegulator().to(self.device)
self.prefer_ds = self.binarization_args['prefer_ds']
self.cached_ds = {}

def load_attr_from_ds(self, ds_id, name, attr, idx=0):
    """
    Read one attribute of a DS segment, loading and caching the DS file on first use.

    Two file layouts are probed inside the ``ds`` folder of the raw data directory
    selected by ``ds_id``, in this order:

    1. ``{name}{DS_INDEX_SEP}{idx}.ds`` - a file dedicated to this segment; the
       segment is the first (usually the only) entry of that file.
    2. ``{name}.ds`` - a file shared by several segments; the segment is entry
       number ``idx`` of that file.

    A file whose top-level JSON value is not a list is treated as a one-element list.
    Parsed files are kept in ``self.cached_ds`` so each file is read at most once.

    :param ds_id: index into ``self.raw_data_dirs``
    :param name: stem of the DS file to look for
    :param attr: key to read from the segment
    :param idx: segment index (0 by default)
    :return: the value stored under ``attr``; ``None`` if neither file exists or
        the segment has no such key
    """
    item_name = f'{ds_id}:{name}'
    item_name_with_idx = f'{item_name}{DS_INDEX_SEP}{idx}'

    # Cache hits: a per-segment file always holds its segment at position 0,
    # a shared file holds it at position idx.
    if item_name_with_idx in self.cached_ds:
        return self.cached_ds[item_name_with_idx][0].get(attr)
    if item_name in self.cached_ds:
        return self.cached_ds[item_name][idx].get(attr)

    # pathlib.Path() makes this work whether the directories are stored as str or Path.
    ds_dir = pathlib.Path(self.raw_data_dirs[ds_id]) / 'ds'
    ds_path = ds_dir / f'{name}{DS_INDEX_SEP}{idx}.ds'
    if ds_path.exists():
        # Per-segment file: index with 0, exactly like the cache-hit branch above.
        # Indexing with idx here would raise IndexError for any idx > 0.
        cache_key, segment_pos = item_name_with_idx, 0
    else:
        ds_path = ds_dir / f'{name}.ds'
        if not ds_path.exists():
            return None
        cache_key, segment_pos = item_name, idx

    with open(ds_path, 'r', encoding='utf8') as f:
        segments = json.load(f)
    if not isinstance(segments, list):
        segments = [segments]
    self.cached_ds[cache_key] = segments
    return segments[segment_pos].get(attr)

def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id):
meta_data_dict = {}

for utterance_label in csv.DictReader(
open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8')
):
utterance_label: dict
item_name = utterance_label['name']
item_idx = int(item_name.rsplit(DS_INDEX_SEP, maxsplit=1)[-1]) if DS_INDEX_SEP in item_name else 0

def require(attr):
if self.prefer_ds:
value = self.load_attr_from_ds(ds_id, item_name, attr, item_idx)
else:
value = None
if value is None:
value = utterance_label.get(attr)
if value is None:
raise ValueError(f'Missing required attribute {attr} of item \'{item_name}\'.')
return value

temp_dict = {
'ds_idx': item_idx,
'spk_id': spk_id,
'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav'),
'ph_seq': utterance_label['ph_seq'].split(),
'ph_dur': [float(x) for x in utterance_label['ph_dur'].split()]
'ph_seq': require('ph_seq').split(),
'ph_dur': [float(x) for x in require('ph_dur').split()]
}

assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \
f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.'

if hparams['predict_dur']:
temp_dict['ph_num'] = [int(x) for x in utterance_label['ph_num'].split()]
temp_dict['ph_num'] = [int(x) for x in require('ph_num').split()]
assert len(temp_dict['ph_seq']) == sum(temp_dict['ph_num']), \
f'Sum of ph_num does not equal length of ph_seq in \'{item_name}\'.'

if hparams['predict_pitch']:
temp_dict['note_seq'] = utterance_label['note_seq'].split()
temp_dict['note_dur'] = [float(x) for x in utterance_label['note_dur'].split()]
temp_dict['note_seq'] = require('note_seq').split()
temp_dict['note_dur'] = [float(x) for x in require('note_dur').split()]
assert len(temp_dict['note_seq']) == len(temp_dict['note_dur']), \
f'Lengths of note_seq and note_dur mismatch in \'{item_name}\'.'
assert any([note != 'rest' for note in temp_dict['note_seq']]), \
Expand Down Expand Up @@ -129,6 +175,10 @@ def check_coverage(self):

@torch.no_grad()
def process_item(self, item_name, meta_data, binarization_args):
ds_id, name = item_name.split(':', maxsplit=1)
name = name.rsplit(DS_INDEX_SEP, maxsplit=1)[0]
ds_id = int(ds_id)
ds_seg_idx = meta_data['ds_idx']
seconds = sum(meta_data['ph_dur'])
length = round(seconds / self.timestep)
T_ph = len(meta_data['ph_seq'])
Expand All @@ -154,11 +204,30 @@ def process_item(self, item_name, meta_data, binarization_args):
processed_input['mel2ph'] = mel2ph.cpu().numpy()

# Below: extract actual f0, convert to pitch and calculate delta pitch
waveform, _ = librosa.load(meta_data['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
if pathlib.Path(meta_data['wav_fn']).exists():
waveform, _ = librosa.load(meta_data['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
elif not self.prefer_ds:
raise FileNotFoundError(meta_data['wav_fn'])
else:
waveform = None

global pitch_extractor
if pitch_extractor is None:
pitch_extractor = initialize_pe()
f0, uv = pitch_extractor.get_pitch(waveform, length, hparams, interp_uv=True)
f0 = uv = None
if self.prefer_ds:
f0_seq = self.load_attr_from_ds(ds_id, name, 'f0_seq', idx=ds_seg_idx)
if f0_seq is not None:
f0 = resample_align_curve(
np.array(f0_seq.split(), np.float32),
original_timestep=float(self.load_attr_from_ds(ds_id, name, 'f0_timestep', idx=ds_seg_idx)),
target_timestep=self.timestep,
align_length=length
)
uv = f0 == 0
f0, _ = interp_f0(f0, uv)
if f0 is None:
f0, uv = pitch_extractor.get_pitch(waveform, length, hparams, interp_uv=True)
if uv.all(): # All unvoiced
print(f'Skipped \'{item_name}\': empty gt f0')
return None
Expand Down Expand Up @@ -208,29 +277,61 @@ def process_item(self, item_name, meta_data, binarization_args):

# Below: extract energy
if hparams['predict_energy']:
energy = get_energy_librosa(waveform, length, hparams).astype(np.float32)

global energy_smooth
if energy_smooth is None:
energy_smooth = SinusoidalSmoothingConv1d(
round(hparams['energy_smooth_width'] / self.timestep)
).eval().to(self.device)
energy = energy_smooth(torch.from_numpy(energy).to(self.device)[None])[0]

processed_input['energy'] = energy.cpu().numpy()
energy = None
energy_from_wav = False
if self.prefer_ds:
energy_seq = self.load_attr_from_ds(ds_id, name, 'energy', idx=ds_seg_idx)
if energy_seq is not None:
energy = resample_align_curve(
np.array(energy_seq.split(), np.float32),
original_timestep=float(self.load_attr_from_ds(
ds_id, name, 'energy_timestep', idx=ds_seg_idx
)),
target_timestep=self.timestep,
align_length=length
)
if energy is None:
energy = get_energy_librosa(waveform, length, hparams).astype(np.float32)
energy_from_wav = True

if energy_from_wav:
global energy_smooth
if energy_smooth is None:
energy_smooth = SinusoidalSmoothingConv1d(
round(hparams['energy_smooth_width'] / self.timestep)
).eval().to(self.device)
energy = energy_smooth(torch.from_numpy(energy).to(self.device)[None])[0].cpu().numpy()

processed_input['energy'] = energy

# Below: extract breathiness
if hparams['predict_breathiness']:
breathiness = get_breathiness_pyworld(waveform, f0 * ~uv, length, hparams).astype(np.float32)

global breathiness_smooth
if breathiness_smooth is None:
breathiness_smooth = SinusoidalSmoothingConv1d(
round(hparams['breathiness_smooth_width'] / self.timestep)
).eval().to(self.device)
breathiness = breathiness_smooth(torch.from_numpy(breathiness).to(self.device)[None])[0]

processed_input['breathiness'] = breathiness.cpu().numpy()
breathiness = None
breathiness_from_wav = False
if self.prefer_ds:
breathiness_seq = self.load_attr_from_ds(ds_id, name, 'breathiness', idx=ds_seg_idx)
if breathiness_seq is not None:
breathiness = resample_align_curve(
np.array(breathiness_seq.split(), np.float32),
original_timestep=float(self.load_attr_from_ds(
ds_id, name, 'breathiness_timestep', idx=ds_seg_idx
)),
target_timestep=self.timestep,
align_length=length
)
if breathiness is None:
breathiness = get_breathiness_pyworld(waveform, f0 * ~uv, length, hparams).astype(np.float32)
breathiness_from_wav = True

if breathiness_from_wav:
global breathiness_smooth
if breathiness_smooth is None:
breathiness_smooth = SinusoidalSmoothingConv1d(
round(hparams['breathiness_smooth_width'] / self.timestep)
).eval().to(self.device)
breathiness = breathiness_smooth(torch.from_numpy(breathiness).to(self.device)[None])[0].cpu().numpy()

processed_input['breathiness'] = breathiness

return processed_input

Expand Down