While working on #644, @msaroufim suggested using the built-in Llama model for testing the mini train recipe. I looked into it, and here are the two main changes to be made:

1. Initialize `freq_cis` without initializing the KV cache and causal mask (`ao/torchao/_models/llama/model.py`, lines 153 to 170 at `e7fc0ed`).
2. Use `is_causal=True` directly in attention (`ao/torchao/_models/llama/model.py`, line 247 at `e7fc0ed`); a quick equivalence check is sketched after this list.
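For change 2: since the training path has no KV cache and processes full sequences (query length equals key length), slicing a precomputed causal mask should give the same result as letting SDPA apply causal masking itself. A minimal standalone check (shapes are illustrative, not taken from model.py):

```python
import torch
import torch.nn.functional as F

# With no KV cache (query length == key length), an explicit lower-triangular
# boolean mask and is_causal=True produce the same attention output.
q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out_masked, out_causal, atol=1e-6)
```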
This will make it convenient for some of our training recipes (e.g. QAT) to have mini training scripts directly in torchao, which can also act as self-contained examples.
API-wise, I think we can add a `training` flag to the `Transformer.setup_caches()` method:

- When `training=False` (default), the old behavior is maintained.
- When `training=True`, only `freq_cis` is initialized, and in the `.forward()` method we don't pass `mask` to `TransformerBlock`/`Attention` (see the sketch after this list).
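A minimal sketch of what I have in mind. This is not the actual torchao code: `MiniTransformer` and `precompute_freqs_cis` are simplified stand-ins, and the attribute names (`freqs_cis`, `causal_mask`) are assumed to follow model.py:

```python
import torch
import torch.nn as nn


def precompute_freqs_cis(seq_len: int, head_dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE frequency table (simplified stand-in for model.py's version).
    freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, (seq_len, head_dim // 2)


class MiniTransformer(nn.Module):
    """Illustrative skeleton; only shows the cache-setup logic under discussion."""

    def __init__(self, n_layers: int = 2, head_dim: int = 16):
        super().__init__()
        self.head_dim = head_dim
        self.layers = nn.ModuleList(nn.Identity() for _ in range(n_layers))
        self.freqs_cis = None
        self.causal_mask = None

    def setup_caches(self, max_batch_size: int, max_seq_length: int,
                     training: bool = False) -> None:
        # freq_cis is needed in both modes (rotary embeddings).
        self.freqs_cis = precompute_freqs_cis(max_seq_length, self.head_dim)
        if training:
            # Proposed: skip the KV cache and causal mask; forward() relies on
            # F.scaled_dot_product_attention(..., is_causal=True) instead.
            return
        # Existing inference path: allocate the causal mask
        # (per-layer KVCache allocation elided).
        self.causal_mask = torch.tril(
            torch.ones(max_seq_length, max_seq_length, dtype=torch.bool)
        )


model = MiniTransformer()
model.setup_caches(max_batch_size=1, max_seq_length=8, training=True)
assert model.freqs_cis is not None and model.causal_mask is None
```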