
Make a clear plan to support Transformer on Fluid. #6876

Closed
lcy-seso opened this issue Dec 22, 2017 · 13 comments
@lcy-seso
Contributor

lcy-seso commented Dec 22, 2017

We are going to support popular NMT models on Fluid, including but not limited to RNN search, ConvS2S, and Transformer.

I think the first important thing for us is to understand the model and figure out the problems.

We choose Google's Transformer as our starting point. Here I list some questions that should be answered:

A tensor2tensor implementation: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

  • About the Transformer:

    • motivation: what problem is it proposed to solve?
    • how is the problem solved?
    • model architecture:
      • Are there any challenges or functional requirements for a DL framework to implement the transformer architecture?
      • Are there any tricks or key points that require our framework to support functionality we have not considered yet?
      • Are there any operators we have not implemented yet?
  • About Fluid (This part is relatively open currently)

    • How is the encoder-decoder architecture implemented in Fluid? Run it and follow its execution/code; do you have any questions or problems about it? (It is OK to ask just from the perspective of a user.)
    • RNN and dynamic RNN. The Transformer does not need any recurrent or CNN layers, but how Fluid processes sequences is still crucial.
    • Propose your questions about Fluid; we will collect them.

At the end of this step, we will share our notes with everyone, about both the Transformer and Fluid, and we can try to make them part of the documentation. All of us should:

  1. have an overall picture of the architecture of the Transformer (from top to bottom, not the reverse);
  2. work out a clear plan for how we are going to implement the Transformer:
    • any special requirements for Fluid?
    • what mechanisms of Fluid are we going to make use of? For example: the while loop, dynamic RNN, or anything else?
  3. make a checklist of the operators needed; the list can be turned directly into an action list:
    • what do we already have?
    • what has not been implemented yet?

Related issue: #6821

@kavyasrinet

@lcy-seso Thanks so much for the writeup. This is very helpful. As discussed in the Hi group, let's pick the items one by one, discuss each of them and figure out a plan.

@lcy-seso
Contributor Author

lcy-seso commented Dec 25, 2017

The Transformer also follows the encoder-decoder architecture. The encoder and decoder are stacks of many identical modules.

Computations/operators each part requires:

  • Positional Embedding.

    • Nothing special; the positional embedding is the sum of the positional encoding and the ordinary word embedding.
  • The encoder is made of: (1) a self-attention part; (2) a position-wise feed-forward network part.

    1. Multi-head self-attention (see the sketch right after this list)
      • matrix multiplication
      • dot product
      • scale
      • sequence softmax
      • residual connection (addition)
      • layer normalization
    2. Position-wise Feed-Forward networks
      • two linear transformations (matrix multiplication and add, namely, the fully connected layer)
      • ReLU activation
      • residual connection (addition)
      • layer normalization
  • Decoder

    1. The operators used by the decoder are the same as those used by the encoder.
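To make the operator list above concrete, here is a minimal single-head, single-sequence NumPy sketch of the self-attention sub-layer built from the elementary operations listed in item 1. The weight names and the helper itself are illustrative, not existing Fluid operators; layer normalization would then be applied to the result.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_sublayer(x, w_q, w_k, w_v, w_o):
    # x: [seq_len, d_model]; w_q/w_k/w_v: [d_model, d_k]; w_o: [d_k, d_model]
    q, k, v = x.dot(w_q), x.dot(w_k), x.dot(w_v)     # matrix multiplications
    scores = q.dot(k.T) / np.sqrt(k.shape[-1])       # dot product + scaling
    weights = softmax(scores, axis=-1)               # softmax over the sequence
    context = weights.dot(v)                         # weighted sum of the values
    return x + context.dot(w_o)                      # residual connection (addition)

A real multi-head version projects into n_head smaller subspaces, runs the same computation per head, and concatenates the per-head contexts before the output projection.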

Because the Transformer does not depend on any recurrence or convolutions, there are not many operators that we have not implemented yet. I think the only one is layer normalization.
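For reference, layer normalization normalizes each position over its feature dimension with a learnable scale and shift; a minimal NumPy sketch (the names are illustrative):

import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    # x: [seq_len, d_model]; gamma, beta: learnable scale/shift of shape [d_model]
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta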

But one difficulty, I think, is how to implement self-attention efficiently. Theoretically, this step can be highly parallelized. I guess we can make use of while_op or dynamic RNN for the first version.

The architecture of the Transformer is quite simple. I guess there will be many tricks to tune once we can successfully run the model.

@lcy-seso
Contributor Author

lcy-seso commented Dec 25, 2017

I am adding a brief TODO list first for our discussion.

We cannot move directly into the Transformer: the basic functionality it depends on also needs to be tested and debugged first. The core idea of the Transformer (high parallelization, dispensing with recurrence and convolution) is also used in ConvS2S.

The list still needs priorities and a time schedule. The items can be done in parallel.



  • Implement the RNN encoder-decoder with attention: Implement attention based on RNN encoder-decoder #6912 (I think this is a very important step right now).
    • This model has not been implemented and tested yet.
    • add all the needed Python wrappers.
    • verify the attention, which depends on dynamic RNN.
    • profile it once it can be run smoothly.

The above steps help us debug the framework and guarantee that the basic functionality is ready.


  • Add the layer normalization operator.
  • add python wrapper.
  • test and verify it on a simple task like text classification.

  • Build the transformer module by module.
    • Implement self-attention. We can debug self-attention first (make it run smoothly) on a simple task like text classification.
    • wrap the multi-head attention.
    • wrap the masked multi-head attention for the decoder.
    • wrap the positional embeddings.
    • wrap the basic computation block (see the sketch right after this list):
      • wrap the multi-head self-attention + residual connection + layer normalization.
      • wrap the position-wise feed-forward networks.
    • stack the basic computation blocks.
    • add the encoder-decoder attention.
    • stack the decoder.
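As referenced in the list above, here is a rough NumPy sketch of the feed-forward sub-layer inside one basic computation block: two linear transformations with a ReLU in between, plus the residual connection and layer normalization. The names are illustrative, and layer_norm is the helper sketched in an earlier comment.

import numpy as np

def ffn_sublayer(x, w1, b1, w2, b2, gamma, beta):
    # x: [seq_len, d_model]; w1: [d_model, d_ff]; w2: [d_ff, d_model]
    hidden = np.maximum(0.0, x.dot(w1) + b1)    # first fully connected layer + ReLU
    out = hidden.dot(w2) + b2                   # second fully connected layer
    return layer_norm(x + out, gamma, beta)     # residual connection + layer normalization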

Not a top priority right now.

  • Also, we need to learn and try the Transformer in t2t (tensor2tensor) to clarify some detailed tricks. (This is not the most important thing currently, but we may need it in the future.)
  • The mechanism we use to implement self-attention is, I think, the same technique as in dynamic RNN (internally, dynamic RNN is built upon while_op), and it can also be used to implement the attention in ConvS2S. If things go well, I think we can also support the ConvS2S model.
  • Verify beam search for the Transformer.
  • Make sure the learning process converges stably. To tune the model, we need multi-GPU or cluster training.
  • Profile and optimize.

@lcy-seso
Contributor Author

lcy-seso commented Dec 26, 2017

The following animation is from the Google research blog; it nicely demonstrates how the Transformer works.

[Animation: how the Transformer works.]

A nice PPT about Transformer: https://nlp.stanford.edu/seminar/details/lkaiser.pdf

@pkuyym
Contributor

pkuyym commented Dec 29, 2017

Currently, I have surveyed some implementations of the Transformer.
The authoritative version is: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
Besides, there is a simple and direct version implemented in PyTorch: https://github.com/jadore801120/attention-is-all-you-need-pytorch

@lcy-seso lcy-seso added the NMT label Jan 3, 2018
@lcy-seso
Contributor Author

lcy-seso commented Jan 3, 2018

We still need a better design for the Transformer. while_op is important for many seq2seq tasks, but it is not necessary for the Transformer, because while_op still computes timestep by timestep.

@pkuyym Since you have already surveyed some implementations of the Transformer, I have a few questions: how do they implement self-attention in a time- and memory-efficient way? Does it need some special operator, such as "broadcast"? Is self-attention based on elementary operators, or can we just write a very specific operator for it?

@pkuyym
Contributor

pkuyym commented Jan 3, 2018

@lcy-seso Yes, while_op is not necessary.

I only surveyed the PyTorch version carefully to check the details of Transformer.

  1. How do they implement self-attention in a time- and memory-efficient way?

torch.bmm is the key operator in the implementation of self-attention; I think parallelization across time steps is handled inside the operator. I will paste some snippets to make things clear:

q_s = q.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head x (mb_size*len_q) x d_model
k_s = k.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head x (mb_size*len_k) x d_model
v_s = v.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head x (mb_size*len_v) x d_model

# treat the result as a (n_head * mb_size) size batch
q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)   # (n_head*mb_size) x len_q x d_k
k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)   # (n_head*mb_size) x len_k x d_k
v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)   # (n_head*mb_size) x len_v x d_v

outputs, attns = self.attention(q_s, k_s, v_s, attn_mask=attn_mask.repeat(n_head, 1, 1))

The implementation of attention function is:

attn = torch.bmm(q, k.transpose(1, 2)) / self.temper
if attn_mask is not None:
    assert attn_mask.size() == attn.size(), \
        'Attention mask shape {} mismatch ' \
        'with Attention logit tensor shape ' \
        '{}.'.format(attn_mask.size(), attn.size())
    attn.data.masked_fill_(attn_mask, -float('inf'))

attn = self.softmax(attn)
attn = self.dropout(attn)
output = torch.bmm(attn, v)

As we can see, the key operation is torch.bmm, which is short for batched matrix multiplication (see the small shape example at the end of this comment).

  2. Does it need some special operator, such as "broadcast"?
    The key, value, and query are tiled n_head times to exploit the efficient bmm operator.

  3. Is self-attention based on elementary operators, or can we just write a very specific operator for it?
    Since the above implementation takes the input as a padded tensor, I think we need to do some adaptation to support non-padded tensors.
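For reference (as mentioned above), torch.bmm multiplies two batches of matrices, one matrix multiplication per batch element, which is why tiling the heads into the batch dimension parallelizes all heads and all positions at once; a small shape check with made-up sizes:

import torch

q = torch.randn(16, 10, 64)              # (n_head * mb_size) x len_q x d_k
k = torch.randn(16, 12, 64)              # (n_head * mb_size) x len_k x d_k
attn = torch.bmm(q, k.transpose(1, 2))   # one matmul per batch element
print(attn.shape)                        # torch.Size([16, 10, 12])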

@lcy-seso
Contributor Author

lcy-seso commented Jan 5, 2018

How do they implement self-attention in a time- and memory-efficient way?

I think I understand the point. Both ConvS2S and the Transformer use "dot-product attention". This kind of attention is special: for a single sequence, the "dot product" step can be implemented using the outer product, so it is highly efficient.

But when it comes to batched computation over variable-length sequences, it cannot be implemented directly with the outer product. I will think more about whether we need padding. ConvS2S has the same problem.

@lcy-seso
Contributor Author

lcy-seso commented Jan 8, 2018

Thanks to @guoshengCS, I think I now better understand how to batch-compute the self-attention and the attention in ConvS2S.

  1. Both the Transformer and ConvS2S use a special kind of attention, dot-product attention, which can be implemented by invoking batched matrix multiplication twice (matmul_op); see the sketch below.

    • The attention weights and the context vectors are each computed by one batched matrix multiplication, which is the main difference from additive attention.
    • I will create a new issue to explain this, because it requires modifications to the current look_up_table operator.
  2. In the Transformer and ConvS2S, we have to pad the variable-length sequences in one batch to the same length in order to batch the computation.
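A rough NumPy sketch of points 1 and 2 together, assuming all sequences in a batch are padded to the same length and the padded key positions are masked out before the softmax (the function and mask names are illustrative, not existing operators):

import numpy as np

def batched_dot_product_attention(q, k, v, pad_mask):
    # q: [batch, len_q, d_k]; k: [batch, len_k, d_k]; v: [batch, len_k, d_v]
    # pad_mask: [batch, len_q, len_k], True where the key position is padding
    scores = np.matmul(q, k.transpose(0, 2, 1)) / np.sqrt(q.shape[-1])  # 1st batched matmul
    scores = np.where(pad_mask, -1e9, scores)                           # mask padded positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)             # softmax -> attention weights
    return np.matmul(weights, v)                                        # 2nd batched matmul -> context vectors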

@pkuyym
Contributor

pkuyym commented Jan 8, 2018

It seems that an implementation based on while_op is not an optimal choice. Please consider removing the padding part.

@lcy-seso
Contributor Author

lcy-seso commented Jan 8, 2018

The padding cannot be removed for the Transformer and ConvS2S.

@lcy-seso
Contributor Author

I am closing this issue for now. We can discuss any problems found in a new issue, or reopen this one if needed. Thanks, everyone.
