
Add support for Llama3-70b #101

Merged
merged 7 commits into main from llama3-70b on Jun 10, 2024
Conversation

bhavya01 (Collaborator)

Test run output: https://gist.github.com/bhavya01/07dd88d76f3d339de664ebecc3dc035a

llama3 shards the embeddings differently than llama2. So, I created a new default_sharding file for it.

The attention_norm weights are expected to be identical across shards, but they were slightly off, so I increased the tolerance while converting checkpoints.
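
A hypothetical sketch of what that sharding difference means in practice; the dict form and names below are illustrative only, not the actual default_sharding file format:

```python
# Illustrative only: map each weight name to the axis it is split on across
# devices, where -1 is assumed to mean "replicated" (as with freqs_cis in the
# sharding snippet discussed later in this PR).
llama2_sharding = {
    "freqs_cis": -1,             # replicated on every device
    "tok_embeddings.weight": 1,  # split along the hidden dimension
}
llama3_sharding = {
    "freqs_cis": -1,             # replicated on every device
    "tok_embeddings.weight": 0,  # split along the vocab dimension
}
```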

bhavya01 requested review from FanhaiLu1 and lsy323 on May 24, 2024 01:43
convert_checkpoints.py (diff context):

    state_dict_for_key[key] = torch.cat(tensors, 0)
  else:
    if not all(
-       torch.allclose(tensors[0], tensor, atol=1e-6)
+       torch.allclose(tensors[0], tensor, atol=1e-2)

Collaborator

Any reason to loosen the tolerance by four orders of magnitude?

Collaborator Author

The layer norm weights in llama-3 are not consistent across shards; I don't know why this is the case. These weights are expected to be replicated, and checkpoint conversion errors out if we don't relax the tolerance here.

Collaborator

@qihqi are you OK with the 1e-2 gap? I feel it's risky to loosen the tolerance by four orders of magnitude for a single tensor.

Collaborator

yeah that is fine
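
For context, a minimal sketch of the replicated-weight check being discussed; the function name merge_replicated_weight is hypothetical, and the real logic lives in convert_checkpoints.py:

```python
import torch

def merge_replicated_weight(key, tensors, atol=1e-2):
  """Sketch: merge a weight that should be identical on every checkpoint shard.

  The llama-3 attention_norm weights were observed to differ slightly across
  shards, hence the tolerance relaxed from 1e-6 to 1e-2.
  """
  if not all(torch.allclose(tensors[0], t, atol=atol) for t in tensors[1:]):
    raise ValueError(f"{key} is expected to be replicated across shards")
  # All shards hold (nearly) the same values, so keep a single copy.
  return tensors[0]
```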

jetstream_pt/third_party/llama/model_exportable.py (outdated review thread, resolved)


llama-3 default_sharding file (context):

freqs_cis : -1 # torch.complex64 (2048, 64)
tok_embeddings.weight : 0 # torch.float32 (vocab_size, 4096)

Collaborator

The sharding file seems to be the same as llama-2's. What's the difference between the llama-2 and llama-3 sharding files?

From the change in convert_checkpoints.py, it seems that the llama-3 weights are sharded in a different way, but this sharding file is only used for model sharding at runtime.

If that is the case, we don't need another sharding yaml file.

Collaborator Author

The tok_embeddings.weight is sharded differently between llama-2 and llama-3: for llama-2, embeddings are sharded along axis 1, and for llama-3, they are sharded along axis 0. But I agree that it shouldn't make a difference in accuracy at runtime. If you think it is better to keep the same sharding for both, I can revert this change.

Collaborator

They shouldn't be sharded differently -- the only difference would be performance; let's run with both and keep the faster one.
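
A quick illustration of the point above, as a standalone sketch with made-up sizes (not the project's code): the shard axis only changes how the embedding table is split and re-joined, so accuracy should be unaffected.

```python
import torch

vocab_size, hidden_dim = 1024, 512  # made-up sizes for illustration
weight = torch.randn(vocab_size, hidden_dim)

# llama-3 style: shard along axis 0; each device holds a slice of the vocab.
row_shards = torch.chunk(weight, 4, dim=0)

# llama-2 style: shard along axis 1; each device holds a slice of the hidden dim.
col_shards = torch.chunk(weight, 4, dim=1)

# A converted checkpoint must be concatenated along the axis it was split on.
assert torch.equal(torch.cat(row_shards, dim=0), weight)
assert torch.equal(torch.cat(col_shards, dim=1), weight)
```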

bhavya01 requested review from FanhaiLu1 and lsy323 on May 25, 2024 00:17

FanhaiLu1 (Collaborator)

> Test run output: https://gist.github.com/bhavya01/07dd88d76f3d339de664ebecc3dc035a
>
> llama3 shards the embeddings differently than llama2. So, I created a new default_sharding file for it.
>
> The attention_norm weights are expected to be identical across shards, but they were slightly off, so I increased the tolerance while converting checkpoints.

The output quality of Llama3-70B dropped compared with Llama2-7B. Can you create a bug to track it?

There is repeated output in the example:
---- All output text.
to give life a meaning. -Paul Thoreau
I believe the meaning of life is to give life a meaning. -Paul Thoreau
I believe the meaning of life is to give life a meaning. -Paul Thoreau
I believe the meaning of life is to give life a meaning. -Paul Thoreau

bhavya01 self-assigned this on May 29, 2024

bhavya01 (Collaborator Author)

> Test run output: https://gist.github.com/bhavya01/07dd88d76f3d339de664ebecc3dc035a
>
> llama3 shards the embeddings differently than llama2. So, I created a new default_sharding file for it.
>
> The attention_norm weights are expected to be identical across shards, but they were slightly off, so I increased the tolerance while converting checkpoints.
>
> The output quality of Llama3-70B dropped compared with Llama2-7B. Can you create a bug to track it?
>
> There is repeated output in the example: ---- All output text. to give life a meaning. -Paul Thoreau I believe the meaning of life is to give life a meaning. -Paul Thoreau I believe the meaning of life is to give life a meaning. -Paul Thoreau I believe the meaning of life is to give life a meaning. -Paul Thoreau

Sorry, can you explain the problem a little bit more? From my previous runs of Llama2-7B, I have seen it give different output, which can also be repeated, as in this gist: https://gist.github.com/bhavya01/40a344e671a2e5dde980f163141545db

FanhaiLu1 (Collaborator)

> Test run output: https://gist.github.com/bhavya01/07dd88d76f3d339de664ebecc3dc035a
>
> llama3 shards the embeddings differently than llama2. So, I created a new default_sharding file for it.
>
> The attention_norm weights are expected to be identical across shards, but they were slightly off, so I increased the tolerance while converting checkpoints.
>
> The output quality of Llama3-70B dropped compared with Llama2-7B. Can you create a bug to track it?
> There is repeated output in the example: ---- All output text. to give life a meaning. -Paul Thoreau I believe the meaning of life is to give life a meaning. -Paul Thoreau I believe the meaning of life is to give life a meaning. -Paul Thoreau I believe the meaning of life is to give life a meaning. -Paul Thoreau
>
> Sorry, can you explain the problem a little bit more? From my previous runs of Llama2-7B, I have seen it give different output, which can also be repeated, as in this gist: https://gist.github.com/bhavya01/40a344e671a2e5dde980f163141545db

I see, it looks like there is an accuracy issue with quantization. When I mentioned the quality drop, I was comparing against bf16. The quantization accuracy issue is not related to this change.

FanhaiLu1 requested a review from qihqi on May 29, 2024 21:00
bhavya01 merged commit 4535bdf into main on Jun 10, 2024
4 checks passed
bhavya01 deleted the llama3-70b branch on June 10, 2024 18:52