-
Notifications
You must be signed in to change notification settings - Fork 9
/
vq_phoneseg.py
executable file
·233 lines (207 loc) · 8.17 KB
/
vq_phoneseg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
#!/usr/bin/env python
"""
Perform phone segmentation on VQ representations.
Author: Herman Kamper
Contact: kamperh@gmail.com
Date: 2021
"""
from pathlib import Path
from tqdm import tqdm
import argparse
import numpy as np
import sys
from vqwordseg import algorithms
#-----------------------------------------------------------------------------#
# UTILITY FUNCTIONS #
#-----------------------------------------------------------------------------#
def check_argv():
    """Parse the command-line arguments, printing help and exiting if none."""
    arg_parser = argparse.ArgumentParser(
        description=__doc__.strip().split("\n")[0], add_help=False
    )

    # Positional arguments
    arg_parser.add_argument(
        "model", help="input VQ representations"
    )
    arg_parser.add_argument("dataset", type=str, help="input dataset")
    arg_parser.add_argument(
        "split", type=str, help="input split"
    )

    # Optional arguments
    arg_parser.add_argument(
        "--input_format",
        help="format of input VQ representations (default: %(default)s)",
        choices=["npy", "txt"], default="txt"
    )
    arg_parser.add_argument(
        "--algorithm",
        help="VQ segmentation algorithm (default: %(default)s)",
        choices=["dp_penalized", "dp_penalized_n_seg", "dp_penalized_hsmm"],
        default="dp_penalized"
    )
    arg_parser.add_argument(
        "--dur_weight", type=float,
        help="the duration penalty weight; if "
        "not specified, a sensible value is chosen based on the input model",
        default=None
    )
    arg_parser.add_argument(
        "--output_tag", type=str, help="used to name the output directory; "
        "if not specified, the algorithm is used",
        default=None
    )
    arg_parser.add_argument(
        "--downsample_factor", type=int,
        help="factor by which the VQ input is downsampled "
        "(default: %(default)s)",
        default=2
    )
    arg_parser.add_argument(
        "--n_frames_per_segment", type=int,
        help="determines the number of segments for dp_penalized_n_seg "
        "(default: %(default)s)",
        default=7
    )
    arg_parser.add_argument(
        "--n_min_segments", type=int,
        help="sets the minimum number of segments for dp_penalized_n_seg "
        "(default: %(default)s)", default=0
    )
    arg_parser.add_argument(
        "--dur_weight_func",
        choices=["neg_chorowski", "neg_log_poisson", "neg_log_hist",
        "neg_log_gamma"], default="neg_chorowski",
        help="function to use for penalizing duration; "
        "if probabilistic, the negative log of the prior is used"
    )
    arg_parser.add_argument(
        "--model_eos", dest="model_eos", action="store_true",
        help="model end-of-sentence"
    )

    # With no arguments at all, show the usage message instead of erroring
    if len(sys.argv) == 1:
        arg_parser.print_help()
        sys.exit(1)
    return arg_parser.parse_args()
#-----------------------------------------------------------------------------#
# MAIN FUNCTION #
#-----------------------------------------------------------------------------#
def main():
    """Run phone segmentation on the VQ representations of a dataset split.

    Reads pre-quantisation features from
    ``exp/<model>/<dataset>/<split>/prequant``, segments each utterance with
    the selected algorithm from ``vqwordseg.algorithms``, and writes one
    ``<utt>.txt`` interval file (``start end code`` per line) under
    ``exp/<model>/<dataset>/<split>/<output_tag>/intervals``.
    """
    args = check_argv()

    # Resolve the segmentation and duration-penalty functions by name
    segment_func = getattr(algorithms, args.algorithm)
    dur_weight_func = getattr(algorithms, args.dur_weight_func)

    # Pick a sensible default duration weight for the given model type
    if args.dur_weight is None:
        if args.model == "vqvae":
            args.dur_weight = 3
        elif args.model == "vqcpc":
            args.dur_weight = 20**2
        elif args.model == "cpc_big":
            args.dur_weight = 3
        else:
            assert False, "cannot set dur_weight automatically for model type"
    if args.algorithm == "dp_penalized_n_seg":
        # The n-seg algorithm fixes the segment count, so no penalty is used
        args.dur_weight = 0
    print(f"Algorithm: {args.algorithm}")
    print(f"Duration weight: {args.dur_weight:.4f}")

    if args.output_tag is None:
        args.output_tag = "phoneseg_{}".format(args.algorithm)

    # Directories and files
    input_dir = Path("exp")/args.model/args.dataset/args.split
    z_dir = input_dir/"prequant"
    print("Reading: {}".format(z_dir))
    assert z_dir.is_dir(), "missing directory: {}".format(z_dir)
    if args.input_format == "npy":
        z_fn_list = sorted(z_dir.glob("*.npy"))
    elif args.input_format == "txt":
        z_fn_list = sorted(z_dir.glob("*.txt"))
    else:
        assert False, "invalid input format"

    # Read embedding matrix (the VQ codebook, shared by all utterances)
    embedding_fn = input_dir.parent/"embedding.npy"
    print("Reading: {}".format(embedding_fn))
    embedding = np.load(embedding_fn)

    # Segment files one-by-one
    output_base_dir = input_dir/args.output_tag
    output_base_dir.mkdir(exist_ok=True, parents=True)
    print("Writing to: {}".format(output_base_dir))
    output_dir = output_base_dir/"intervals"
    output_dir.mkdir(exist_ok=True, parents=True)
    for input_fn in tqdm(z_fn_list):

        # Read pre-quantisation representations
        if args.input_format == "npy":
            z = np.load(input_fn)
        elif args.input_format == "txt":
            z = np.loadtxt(input_fn)

        # Skip degenerate single-frame utterances (loaded as a 1-D array)
        if z.ndim == 1:
            continue

        # Segment
        if args.algorithm == "dp_penalized_n_seg":
            boundaries, code_indices = segment_func(
                embedding, z, dur_weight=args.dur_weight,
                n_frames_per_segment=args.n_frames_per_segment,
                n_min_segments=args.n_min_segments,
                dur_weight_func=dur_weight_func
            )
        else:
            boundaries, code_indices = segment_func(
                embedding, z, dur_weight=args.dur_weight,
                dur_weight_func=dur_weight_func, model_eos=args.model_eos
            )

        # Convert boundaries to same frequency as reference
        if args.downsample_factor > 1:
            boundaries_upsampled = np.zeros(
                len(boundaries)*args.downsample_factor, dtype=bool
            )
            for i, bound in enumerate(boundaries):
                # A boundary after downsampled frame i falls on the last
                # frame of the upsampled segment. Previously this used
                # i*factor + 1, which is only correct for factor == 2; the
                # general end-of-segment index is (i + 1)*factor - 1.
                boundaries_upsampled[
                    (i + 1)*args.downsample_factor - 1
                ] = bound
            boundaries = boundaries_upsampled
            code_indices = [
                (
                    start*args.downsample_factor,
                    end*args.downsample_factor,
                    index
                )
                for start, end, index in code_indices
            ]

        # Merge adjacent segments with the same code label into one interval
        # (single linear pass instead of repeated pop/insert on the list)
        merged = []
        for start, end, label in code_indices:
            if merged and merged[-1][2] == label:
                merged[-1] = (merged[-1][0], end, label)
            else:
                merged.append((start, end, label))
        code_indices = merged

        # Write intervals: one "start end code" line per segment
        utt_key = input_fn.stem
        with open((output_dir/utt_key).with_suffix(".txt"), "w") as f:
            for start, end, index in code_indices:
                f.write("{:d} {:d} {:d}\n".format(start, end, index))


if __name__ == "__main__":
    main()