Skip to content

Commit

Permalink
Merge pull request #173 from kan-bayashi/support_npy_loader
Browse files Browse the repository at this point in the history
  • Loading branch information
kan-bayashi authored Jun 28, 2020
2 parents d507595 + 2c93873 commit 88a829b
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 15 deletions.
2 changes: 1 addition & 1 deletion parallel_wavegan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# -*- coding: utf-8 -*-

__version__ = "0.4.0"
__version__ = "0.4.1"
22 changes: 10 additions & 12 deletions parallel_wavegan/datasets/scp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from torch.utils.data import Dataset

from parallel_wavegan.utils import HDF5ScpLoader
from parallel_wavegan.utils import NpyScpLoader


def _check_feats_scp_type(feats_scp):
def _get_feats_scp_loader(feats_scp):
# read the first line of feats.scp file
with open(feats_scp) as f:
key, value = f.readlines()[0].replace("\n", "").split()
Expand All @@ -27,16 +28,19 @@ def _check_feats_scp_type(feats_scp):
value_1, value_2 = value.split(":")
if value_1.endswith(".ark"):
# kaldi-ark case: utt_id_1 /path/to/utt_id_1.ark:index
return "mat"
return kaldiio.load_scp(feats_scp)
elif value_1.endswith(".h5"):
# hdf5 case with path in hdf5: utt_id_1 /path/to/utt_id_1.h5:feats
return "hdf5"
return HDF5ScpLoader(feats_scp)
else:
raise ValueError("Not supported feats.scp type.")
else:
if value.endswith(".h5"):
# hdf5 case without path in hdf5: utt_id_1 /path/to/utt_id_1.h5
return "hdf5"
return HDF5ScpLoader(feats_scp)
elif value.endswith(".npy"):
# npy case: utt_id_1 /path/to/utt_id_1.npy
return NpyScpLoader(feats_scp)
else:
raise ValueError("Not supported feats.scp type.")

Expand Down Expand Up @@ -69,10 +73,7 @@ def __init__(self,
"""
# load scp as lazy dict
audio_loader = kaldiio.load_scp(wav_scp, segments=segments)
if _check_feats_scp_type(feats_scp) == "mat":
mel_loader = kaldiio.load_scp(feats_scp)
else:
mel_loader = HDF5ScpLoader(feats_scp)
mel_loader = _get_feats_scp_loader(feats_scp)
audio_keys = list(audio_loader.keys())
mel_keys = list(mel_loader.keys())

Expand Down Expand Up @@ -267,10 +268,7 @@ def __init__(self,
"""
# load scp as lazy dict
if _check_feats_scp_type(feats_scp) == "mat":
mel_loader = kaldiio.load_scp(feats_scp)
else:
mel_loader = HDF5ScpLoader(feats_scp)
mel_loader = _get_feats_scp_loader(feats_scp)
mel_keys = list(mel_loader.keys())

# filter by threshold
Expand Down
54 changes: 54 additions & 0 deletions parallel_wavegan/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,57 @@ def values(self):
"""Return the values of the scp file."""
for key in self.keys():
yield self[key]


class NpyScpLoader(object):
    """Loader class for a feats.scp file of npy files.

    Each non-empty line of the scp file maps an utterance key to the path
    of a ``.npy`` file. Arrays are loaded lazily, i.e. only when a key is
    accessed.

    Examples:
        key1 /some/path/a.npy
        key2 /some/path/b.npy
        key3 /some/path/c.npy
        key4 /some/path/d.npy
        ...
        >>> loader = NpyScpLoader("feats.scp")
        >>> array = loader["key1"]

    """

    def __init__(self, feats_scp):
        """Initialize npy scp loader.

        Args:
            feats_scp (str): Kaldi-style feats.scp file with npy format.

        Raises:
            ValueError: If a non-empty line does not contain both a key
                and a path.

        """
        self.data = {}
        with open(feats_scp) as f:
            for lineno, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                # Split only on the first whitespace run so that a path
                # containing spaces is kept intact.
                fields = line.split(None, 1)
                if len(fields) != 2:
                    raise ValueError(
                        "Invalid line %d in %s: %r (expected '<key> <path>')."
                        % (lineno, feats_scp, line)
                    )
                key, value = fields
                self.data[key] = value

    def get_path(self, key):
        """Get npy file path for a given key."""
        return self.data[key]

    def __getitem__(self, key):
        """Get ndarray for a given key."""
        return np.load(self.data[key])

    def __len__(self):
        """Return the length of the scp file."""
        return len(self.data)

    def __iter__(self):
        """Return the iterator of the scp file."""
        return iter(self.data)

    def keys(self):
        """Return the keys of the scp file."""
        return self.data.keys()

    def values(self):
        """Return the values of the scp file."""
        for key in self.keys():
            yield self[key]
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"PyYAML>=3.12",
"tqdm>=4.26.1",
"kaldiio>=2.14.1",
"h5py>=2.10.0",
"h5py>=2.9.0",
"yq>=2.10.0",
# Fix No module named "numba.decorators"
"numba<=0.48",
Expand Down Expand Up @@ -65,7 +65,7 @@

dirname = os.path.dirname(__file__)
setup(name="parallel_wavegan",
version="0.4.0",
version="0.4.1",
url="http://github.com/kan-bayashi/ParallelWaveGAN",
author="Tomoki Hayashi",
author_email="hayashi.tomoki@g.sp.m.is.nagoya-u.ac.jp",
Expand Down

0 comments on commit 88a829b

Please sign in to comment.