From 4340a1efdaddd55fec612261e6d54c841e6dc165 Mon Sep 17 00:00:00 2001
From: Yujia2415 <158730742+Yujia2415@users.noreply.github.com>
Date: Wed, 21 Feb 2024 21:47:37 +0800
Subject: [PATCH] Update Transformer.py

A minor fix to the comment describing the tensor's dimensions.
---
 5-1.Transformer/Transformer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/5-1.Transformer/Transformer.py b/5-1.Transformer/Transformer.py
index fd1421e..7ba9eec 100644
--- a/5-1.Transformer/Transformer.py
+++ b/5-1.Transformer/Transformer.py
@@ -162,7 +162,7 @@ def __init__(self):
     def forward(self, enc_inputs, dec_inputs):
         enc_outputs, enc_self_attns = self.encoder(enc_inputs)
         dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
-        dec_logits = self.projection(dec_outputs) # dec_logits : [batch_size x src_vocab_size x tgt_vocab_size]
+        dec_logits = self.projection(dec_outputs) # dec_logits : [batch_size x tgt_len x tgt_vocab_size]
         return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

 def showgraph(attn):
@@ -223,4 +223,4 @@ def showgraph(attn):
     showgraph(dec_self_attns)

     print('first head of last state dec_enc_attns')
-    showgraph(dec_enc_attns)
\ No newline at end of file
+    showgraph(dec_enc_attns)
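
For reference, the corrected comment follows from the projection layer acting only on the last dimension of dec_outputs, so the sequence axis stays tgt_len rather than src_vocab_size. A minimal shape check is sketched below; the sizes are illustrative assumptions, not the repo's actual hyperparameters.

import torch
import torch.nn as nn

# Illustrative sizes only (assumed for this sketch, not taken from the repo).
batch_size, tgt_len, d_model, tgt_vocab_size = 2, 5, 512, 7

# Stand-in for the model's self.projection: maps d_model -> tgt_vocab_size on the last axis.
projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
dec_outputs = torch.randn(batch_size, tgt_len, d_model)  # stand-in for the decoder output

dec_logits = projection(dec_outputs)
print(dec_logits.shape)  # torch.Size([2, 5, 7]) -> [batch_size x tgt_len x tgt_vocab_size]
print(dec_logits.view(-1, dec_logits.size(-1)).shape)  # torch.Size([10, 7]), the flattened form returned for the loss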