Skip to content

Commit

Permalink
Merge pull request #2415 from Zth9730/u2++_decoder
Browse files Browse the repository at this point in the history
[s2t] support bitransformer decoder
  • Loading branch information
zh794390558 authored Sep 22, 2022
2 parents 52af86f + d3e5937 commit 1a1ce92
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 37 deletions.
121 changes: 112 additions & 9 deletions paddlespeech/audio/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def has_tensor(val):
return True
elif isinstance(val, dict):
for k, v in val.items():
print(k)
if has_tensor(v):
return True
else:
Expand Down Expand Up @@ -143,14 +142,15 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
[ 7, 8, 9, 11, -1, -1]])
"""
# TODO(Hui Zhang): using comment code,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
# _sos = paddle.to_tensor(
# [sos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
# _eos = paddle.to_tensor(
# [eos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place)
# ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
# ys_in = [paddle.concat([_sos, y], axis=0) for y in ys]
# ys_out = [paddle.concat([y, _eos], axis=0) for y in ys]
# return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0])

B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
Expand Down Expand Up @@ -190,3 +190,106 @@ def th_accuracy(pad_outputs: paddle.Tensor,
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator)


def reverse_pad_list(ys_pad: paddle.Tensor,
                     ys_lens: paddle.Tensor,
                     pad_value: float=-1.0) -> paddle.Tensor:
    """Reverse the valid prefix of every row, then pad the rows again.

    Each row of ``ys_pad`` is cast with ``.int()``, truncated to its own
    length from ``ys_lens``, flipped along its only axis, and the resulting
    variable-length rows are handed to ``pad_sequence``.

    Args:
        ys_pad (paddle.Tensor): The padded tensor (B, Tokenmax).
        ys_lens (paddle.Tensor): The number of valid tokens per row (B).
        pad_value (float): Fill value passed on to ``pad_sequence``.

    Returns:
        paddle.Tensor: Whatever ``pad_sequence`` builds from the reversed
            rows (expected to be (B, Tokenmax) -- confirm against
            ``pad_sequence``).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, paddle.to_tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]])
    """
    reversed_rows = []
    for row, length in zip(ys_pad, ys_lens):
        # Drop the padding tail first so that only real tokens get reversed.
        valid_tokens = row.int()[:length]
        reversed_rows.append(paddle.flip(valid_tokens, [0]))
    return pad_sequence(reversed_rows, True, pad_value)


def st_reverse_pad_list(ys_pad: paddle.Tensor,
                        ys_lens: paddle.Tensor,
                        sos: float,
                        eos: float) -> paddle.Tensor:
    """Reverse each padded sequence, fill the tail with eos and prepend sos.

    Loop-free variant built only from tensor ops. It is intended to be
    equivalent to:
        >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
        >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)

    Args:
        ys_pad (paddle.Tensor): The padded tensor (B, Tokenmax).
        ys_lens (paddle.Tensor): The number of valid tokens per row (B).
        sos (float): Value written into the first column of every row.
        eos (float): Value written into every position past a row's length.

    Returns:
        paddle.Tensor: Tensor of shape (B, max(ys_lens) + 1), with the same
            dtype as ``ys_pad``.

    Examples:
        >>> ys_pad
        tensor([[1, 2, 3], [9, 8, 4], [2, 0, 0]])
        >>> ys_lens
        tensor([3, 3, 1])
        >>> st_reverse_pad_list(ys_pad, ys_lens, sos, eos)
        tensor([[sos, 3, 2,   1],
                [sos, 4, 8,   9],
                [sos, 2, eos, eos]])
    """
    batch_size = ys_pad.shape[0]
    sos_column = paddle.full([batch_size, 1], sos, dtype=ys_pad.dtype)

    max_len = paddle.max(ys_lens)
    index_range = paddle.arange(0, max_len, 1)
    seq_len_expand = ys_lens.unsqueeze(1)
    # True where a position lies inside the row's valid length.
    seq_mask = seq_len_expand > index_range  # (B, max_len)

    # Source column for every output column: last valid token first.
    index = (seq_len_expand - 1) - index_range  # (B, max_len)
    # >>> index
    # >>> tensor([[ 2,  1,  0],
    # >>>         [ 2,  1,  0],
    # >>>         [ 0, -1, -2]])
    # Positions past the row length hold negative indices; zero them so the
    # gather below stays in range. Those slots are overwritten with eos later.
    index = index * seq_mask
    # >>> index
    # >>> tensor([[2, 1, 0],
    # >>>         [2, 1, 0],
    # >>>         [0, 0, 0]])

    def paddle_gather(x, dim, index):
        """Pick ``x`` entries along ``dim`` using ``index`` via ``gather_nd``."""
        index_shape = index.shape
        index_flatten = index.flatten()
        if dim < 0:
            dim = len(x.shape) + dim
        # One flat coordinate vector per axis of x: the caller's index on
        # `dim`, and a broadcast arange on every other axis.
        nd_index = []
        for k in range(len(x.shape)):
            if k == dim:
                nd_index.append(index_flatten)
            else:
                reshape_shape = [1] * len(x.shape)
                reshape_shape[k] = x.shape[k]
                x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
                x_arange = x_arange.reshape(reshape_shape)
                dim_index = paddle.expand(x_arange, index_shape).flatten()
                nd_index.append(dim_index)
        # (num_axes, N) -> (N, num_axes): one full coordinate per row.
        ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
        paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
        return paddle_out

    r_hyps = paddle_gather(ys_pad, 1, index)
    # >>> r_hyps
    # >>> tensor([[3, 2, 1],
    # >>>         [4, 8, 9],
    # >>>         [2, 2, 2]])

    # Use a separate local name so the `eos` argument is not shadowed.
    eos_fill = paddle.full([1], eos, dtype=r_hyps.dtype)
    r_hyps = paddle.where(seq_mask, r_hyps, eos_fill)
    # >>> r_hyps
    # >>> tensor([[3, 2,   1],
    # >>>         [4, 8,   9],
    # >>>         [2, eos, eos]])

    # Native Paddle spells this `concat(..., axis=...)`, not `cat(..., dim=...)`.
    r_hyps = paddle.concat([sos_column, r_hyps], axis=1)
    # >>> r_hyps
    # >>> tensor([[sos, 3, 2,   1],
    # >>>         [sos, 4, 8,   9],
    # >>>         [sos, 2, eos, eos]])
    return r_hyps
5 changes: 3 additions & 2 deletions paddlespeech/s2t/exps/u2/bin/test_wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, config, args):
self.preprocess_conf = config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)

self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
self.text_feature = TextFeaturizer(
unit_type=config.unit_type,
vocab=config.vocab_filepath,
Expand Down Expand Up @@ -89,7 +89,8 @@ def run(self):
ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming)
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=self.reverse_weight)
rsl = result_transcripts[0][0]
utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
Expand Down
6 changes: 4 additions & 2 deletions paddlespeech/s2t/exps/u2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ def setup_model(self):
model_conf.output_dim = self.test_loader.vocab_size

model = U2Model.from_config(model_conf)

if self.parallel:
model = paddle.DataParallel(model)

Expand Down Expand Up @@ -317,6 +316,7 @@ def __init__(self, config, args):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)

def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """
Expand All @@ -341,6 +341,7 @@ def compute_metrics(self,

start_time = time.time()
target_transcripts = self.id2token(texts, texts_len, self.text_feature)

result_transcripts, result_tokenids = self.model.decode(
audio,
audio_len,
Expand All @@ -350,7 +351,8 @@ def compute_metrics(self,
ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming)
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=self.reverse_weight)
decode_time = time.time() - start_time

for utt, target, result, rec_tids in zip(
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/s2t/io/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def get_dataloader(mode: str, config, args):
elif mode == 'valid':
config['manifest'] = config.dev_manifest
config['train_mode'] = False
elif model == 'test' or mode == 'align':
elif mode == 'test' or mode == 'align':
config['manifest'] = config.test_manifest
config['train_mode'] = False
config['dither'] = 0.0
Expand Down
Loading

0 comments on commit 1a1ce92

Please sign in to comment.