-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix LLaMA tokenization issue #531
Conversation
Hello @gakada and @haileyschoelkopf , I would really appreciate it if you can explain the fix to me and why it only affects LLaMA models. |
@HaniItani it is because LLaMA tokenizer doesn't satisfy import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')
model = AutoModelForCausalLM.from_pretrained('huggyllama/llama-7b', device_map='auto')
enc = lambda s: tokenizer.encode(s, add_special_tokens=False)
def test(question_enc, answer_enc):
# get b logprobs out of model(a + b)
logprobs = F.log_softmax(model(torch.tensor(question_enc + answer_enc).unsqueeze(0).cuda())['logits'], dim=-1).cpu().squeeze(0)[len(question_enc) - 1 : -1]
guess = logprobs.argmax(dim=-1)
print(f'|{tokenizer.decode(guess)}| |{tokenizer.decode(answer_enc)}|')
# assume enc(a + b) = enc(a) + enc(b)
test(enc('Paris is the capital of'), enc(' France.'))
# |France2,| | France.|
# assume enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]
sentence_enc = enc('Paris is the capital of France.')
question_enc = enc('Paris is the capital of')
answer_enc = sentence_enc[len(question_enc):]
test(question_enc, answer_enc)
# |France.| |France.| This doesn't affect other tokenizers, though in all cases it is also assumed that |
@gakada thank you very much for your elaborate response. It's all clear to me now. |
+1 |
* Fix tokenization issue in BaseLM.loglikelihood * Add a regression script * Use entire non-continuation length as context --------- Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
* Fix tokenization issue in BaseLM.loglikelihood * Add a regression script * Use entire non-continuation length as context --------- Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
As per the regression script: