Add support for Baichuan2 models #1022

Closed
wants to merge 6 commits into from

Conversation

@garyfanhku commented Sep 12, 2023

Added Baichuan2 model and config, registered Baichuan2 as a new model. Added an offline inference example for validating generation outputs with models using chat format.
Notes:

  1. As repetition_penalty has not yet been added to 0.1.7 and Baichuan2's default configuration uses a repetition penalty of 1.05, text generation outputs cannot be replicated one-to-one.
  2. NormHead() is not yet implemented as a module. Instead, NormHead is applied in load_weights() per the suggestion from @nexa123, although it did not affect the output in my testing.
  3. Tensor parallelism is fully supported and tested.
  4. (*Potential bug) Text generation output appears to be prepended with a whitespace. The leading whitespace is caused by prompts not being formatted in chat format; see the details below.
  5. Currently, LLM.generate() does not seem to handle chat-formatted prompts like [{"role":"user", "content":"..."}], which is the format adopted by the Baichuan 1 & 2 models. I therefore added an example, offline_inference_baichuan.py, which fixes the prompt format by copying build_chat_input() from the Baichuan2 repo and passing prompt_token_ids directly (a minimal sketch of this approach follows the comparison below). I also tested output consistency against non-vLLM pipelines; vLLM's generated outputs look largely consistent with Baichuan2's local inference results. Below is an example comparison.
>>>>> Baichuan2-13B-Chat FP16 vLLM 
Prompt: "The future of AI is",
 Generated text: "The future of AI is promising and full of potential. Artificial intelligence (AI) is expected to play an increasingly important role in various aspects of our lives, including healthcare, education, transportation, and entertainment. Some of the key trends in AI's future development include:

1. AI in healthcare: AI is expected to revolutionize healthcare by improving diagnostics, treatment planning, and drug discovery. AI algorithms can analyze medical images, patient data, and medical literature to identify patterns and make predictions, leading to better patient outcomes.

2. AI in education: AI-powered tools and platforms are being developed to personalize learning experiences, identify students' strengths and weaknesses, and provide real-time feedback. This can help students learn more effectively and at their own pace.

3. AI in transportation: Autonomous vehicles are expected to become a reality in the near future, with AI playing a crucial role in their development and deployment. AI can help improve traffic flow, reduce accidents, and optimize public transportation systems.

4. AI in entertainment: AI is being used to create more immersive and personalized entertainment experiences, such as video games, movies, and music. AI can also help create new forms of entertainment, such as virtual reality experiences and interactive storytelling.

5. AI ethics and regulation: As AI becomes more integrated into our lives, ethical and regulatory concerns will also grow. We can expect to see more discussions and legislation surrounding issues such as data privacy, algorithmic bias, and the potential for AI to displace human jobs.

6. AI and the job market: AI has the potential to both create new job opportunities and displace existing ones. As AI systems become more capable of performing tasks previously done by humans, some jobs may become obsolete. However, AI can also lead to the creation of new jobs, such as AI developers, data scientists, and AI ethics specialists.

7. AI and the environment: AI can play a significant role in addressing environmental challenges, such as climate change and resource depletion. AI can help optimize energy consumption, improve waste management, and monitor and predict natural disasters.

8. AI and the global economy: AI has the potential to significantly boost global economic growth by increasing productivity, reducing costs, and fostering innovation. However, the benefits of AI may not be evenly distributed, leading to increased income inequality and concerns about the digital divide.

In summary, the future of AI is full of possibilities and challenges. As AI continues to develop and become more integrated into our lives, it is crucial for individuals, businesses, and governments to understand and address the ethical, social, and economic implications of this technology."
 
>>>>>> Baichuan2-13B-Chat 8Bit Demo:
The future of AI is promising and full of potential. It is expected to revolutionize various industries, including healthcare, finance, transportation, manufacturing, and education. Some of the key areas where AI is expected to make a significant impact include:

1. Healthcare: AI will play a crucial role in improving diagnostics, treatment planning, and personalized medicine. Machine learning algorithms can analyze medical images, patient records, and genetic data to identify patterns and make accurate predictions. This will help in early detection of diseases, better treatment options, and improved patient outcomes.

2. Finance: AI will transform the financial industry by automating processes, enhancing risk management, and improving customer service. AI-powered algorithms can analyze vast amounts of financial data to detect fraud, assess credit risk, and optimize investment strategies.

3. Transportation: Autonomous vehicles are expected to become a reality in the near future, thanks to advancements in AI and machine learning. Self-driving cars have the potential to reduce traffic accidents, improve fuel efficiency, and revolutionize urban transportation systems.

4. Manufacturing: AI will enhance productivity and efficiency in the manufacturing sector by automating repetitive tasks, optimizing supply chains, and predicting equipment failures. This will enable companies to produce higher-quality products at a lower cost.

5. Education: AI will transform the way we learn and teach by providing personalized learning experiences, identifying students' strengths and weaknesses, and offering targeted support. AI-powered tools can also help educators in managing classrooms, tracking student progress, and identifying learning gaps.

8. Customer Service: AI-powered chatbots and virtual assistants will improve customer service by handling routine inquiries, providing personalized recommendations, and assisting human agents in resolving complex issues.

9. Environment and Sustainability: AI can help monitor and predict environmental changes, optimize resource usage, and develop sustainable solutions for climate change and pollution control.

However, the future of AI also comes with its challenges, such as job displacement, privacy concerns, and ethical considerations. As AI continues to advance, it is essential to address these challenges and ensure that its benefits are accessible to everyone.
>>>>>
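
For reference, a minimal sketch of the approach described in note 5 (not the exact offline_inference_baichuan.py from this PR): tokenize the chat-formatted messages yourself and pass prompt_token_ids to LLM.generate(). The helper below is a simplified, hypothetical stand-in for build_chat_input() from the Baichuan2 repo, and the user/assistant token ids (195/196) are assumptions taken from Baichuan's generation config.

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

MODEL = "baichuan-inc/Baichuan2-13B-Chat"

def build_chat_prompt_ids(tokenizer, messages,
                          user_token_id=195, assistant_token_id=196):
    # Simplified stand-in for Baichuan2's build_chat_input(): prefix each message
    # with its role token and end with the assistant token so the model starts
    # generating the assistant turn.
    ids = []
    for message in messages:
        role_id = user_token_id if message["role"] == "user" else assistant_token_id
        ids.append(role_id)
        ids.extend(tokenizer.encode(message["content"]))
    ids.append(assistant_token_id)
    return ids

tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
llm = LLM(model=MODEL, trust_remote_code=True)

messages = [{"role": "user", "content": "The future of AI is"}]
prompt_token_ids = [build_chat_prompt_ids(tokenizer, messages)]

sampling_params = SamplingParams(temperature=0.3, top_p=0.85, max_tokens=512)
outputs = llm.generate(prompt_token_ids=prompt_token_ids,
                       sampling_params=sampling_params)
print(outputs[0].outputs[0].text)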

@zhuohan123 added the new model label Sep 12, 2023
@nexa123 commented Sep 13, 2023

Without NormHead, we can normalize the lm_head weight when loading the tensor. Baichuan and Baichuan2's ALiBi mask is different from vLLM's implementation, which causes divergence of outputs.
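
For context, a minimal sketch of the NormHead idea (illustrative shapes, not the PR's actual code): Baichuan2's lm_head L2-normalizes its weight rows, so at inference time the same effect can be obtained by normalizing the weight once when it is loaded, which is the load-time approach suggested above.

import torch
import torch.nn.functional as F

# Illustrative shapes only: roughly (vocab_size, hidden_size) for Baichuan2-13B.
lm_head_weight = torch.randn(125696, 5120)

# NormHead computes logits against the row-normalized weight; normalizing once
# at load time and then using a plain linear head is equivalent at inference.
normalized_weight = F.normalize(lm_head_weight)  # row-wise L2 normalization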

@calvin1978

Without NormHead, we can normalize the lm_head weight when loading the tensor. Baichuan and Baichuan2's ALiBi mask is different from vLLM's implementation, which causes divergence of outputs.

Baichuan2's ALiBi mask code looks different, but is the result the same (same problem)?
vLLM implements it correctly in PagedAttentionWithALiBi, following the original ALiBi paper.

@calvin1978

Without NormHead, we can normalize the lm_head weight when loading the tensor. Baichuan and Baichuan2's ALiBi mask is different from vLLM's implementation, which causes divergence of outputs.

Yes, we should only normalize it once when loading the tensor in load_weights(); just modify the original Baichuan model to do it when the model version is 2.

@nexa123 commented Sep 13, 2023

Without NormHead, we can normalize the lm_head weight when loading the tensor. Baichuan and Baichuan2's ALiBi mask is different from vLLM's implementation, which causes divergence of outputs.

Baichuan2's ALiBi mask code looks different, but is the result the same (same problem)? vLLM implements it correctly in PagedAttentionWithALiBi, following the original ALiBi paper.

No, the fp16 results are the same. However, the ALiBi masks under bf16 are different, especially when the context is long (for example, 4096).

@calvin1978

Without NormHead, we can normalize the lm_head weight when loading the tensor. Baichuan and Baichuan2's ALiBi mask is different from vLLM's implementation, which causes divergence of outputs.

Baichuan2's ALiBi mask code looks different, but is the result the same (same problem)? vLLM implements it correctly in PagedAttentionWithALiBi, following the original ALiBi paper.

No, the fp16 results are the same. However, the ALiBi masks under bf16 are different, especially when the context is long (for example, 4096).

The ALiBi mask implementation in vLLM, GPT-NeoX, and MPT follows the original paper, which is better than the implementation in Baichuan and Bloom.

@nexa123 commented Sep 13, 2023

Without NormHead, we can normalize the lm_head weight when loading the tensor. Baichuan and Baichuan2's ALiBi mask is different from vLLM's implementation, which causes divergence of outputs.

Baichuan2's ALiBi mask code looks different, but is the result the same (same problem)? vLLM implements it correctly in PagedAttentionWithALiBi, following the original ALiBi paper.

No, the fp16 results are the same. However, the ALiBi masks under bf16 are different, especially when the context is long (for example, 4096).

The ALiBi mask implementation in vLLM, GPT-NeoX, and MPT follows the original paper, which is better than the implementation in Baichuan and Bloom.

Maybe this article gives the reason for Baichuan's special ALiBi mask implementation.

@garyfanhku (Author) commented Sep 13, 2023

Without NormHead, we can normalize the lm_head weight when loading the tensor. Baichuan and Baichuan2's ALiBi mask is different from vLLM's implementation, which causes divergence of outputs.

Baichuan2's ALiBi mask code looks different, but is the result the same (same problem)? vLLM implements it correctly in PagedAttentionWithALiBi, following the original ALiBi paper.

I tried to hand-calculate Baichuan2's ALiBi mask vs. Baichuan1's in vLLM. Overall, it doesn't look drastically different.

Baichuan2's ALiBi slopes are only slightly different from those in vLLM's Baichuan1 implementation, _get_alibi_slopes() (floating-point precision, perhaps?).

import math

import torch

# Baichuan2's slope computation (copied from its modeling code):
def _get_interleave(n):
    def _get_interleave_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return _get_interleave_power_of_2(n)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            _get_interleave_power_of_2(closest_power_of_2)
            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )

# vLLM's slope computation for Baichuan1:
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(
            closest_power_of_2, total_num_heads - closest_power_of_2
        )
        extra_powers = torch.arange(
            start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
        )
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


n_head = 32
slopes_baichuan2 = torch.Tensor(_get_interleave(n_head))
slopes_vllm_baichuan1 = _get_alibi_slopes(n_head)

slopes_vllm_baichuan1 - slopes_baichuan2
tensor([ 0.0000e+00, -5.9605e-08, -5.9605e-08,  0.0000e+00, -2.9802e-08,
        -2.9802e-08, -2.9802e-08, -2.9802e-08, -2.9802e-08, -2.9802e-08,
        -2.9802e-08, -1.4901e-08, -1.4901e-08, -2.2352e-08, -1.4901e-08,
        -1.4901e-08, -1.1176e-08, -1.1176e-08, -1.1176e-08, -7.4506e-09,
        -7.4506e-09, -7.4506e-09, -7.4506e-09, -5.5879e-09, -4.6566e-09,
        -4.6566e-09, -3.7253e-09, -2.7940e-09, -2.7940e-09, -2.3283e-09,
        -2.3283e-09, -1.8626e-09])

Mask generation looks fine

# given the same precomputed slopes for both implementations
slopes = slopes_baichuan2
max_pos = 4096

def _fill_with_neg_inf(t):
    # helper from the Baichuan modeling code: fill a tensor with -inf (fp16-compatible)
    return t.float().fill_(float("-inf")).type_as(t)

def baichuan2_alibi(input_slopes):
    alibi = input_slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(
        0
    ).unsqueeze(0).expand(n_head, -1, -1)
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    return alibi_mask

baichuan2_alibi_mask = baichuan2_alibi(slopes)

def baichuan1_alibi(input_slopes, n_head=32):
    position_point = torch.arange(max_pos) - max_pos + 1
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
    diag = torch.diag(position_point[0])
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
    alibi = input_slopes.unsqueeze(1).unsqueeze(1) * position_point
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    return alibi_mask

baichuan1_alibi_mask = baichuan1_alibi(slopes)

torch.all(baichuan1_alibi_mask == baichuan2_alibi_mask)
tensor(True)

Baichuan2's ALiBi mask generation is consistent with Baichuan1's. Why they modified the implementation is beyond me, though.

@garyfanhku (Author)

Without NormHead, we can normalize the lm_head weight when loading the tensor. Baichuan and Baichuan2's ALiBi mask is different from vLLM's implementation, which causes divergence of outputs.

Thanks for the suggestion, I believe it is also mentioned in the migration guide: https://github.com/baichuan-inc/Baichuan2/blob/main/README_EN.md#migrating-inference-optimizations-from-baichuan-1-to-baichuan-2

I have pushed a new commit adding NormHead; please check the updated PR comment.

@calvin1978

Without NormHead, we can normalize the lm_head weight when loading the tensor. Baichuan and Baichuan2's ALiBi mask is different from vLLM's implementation, which causes divergence of outputs.

Baichuan2's ALiBi mask code looks different, but is the result the same (same problem)? vLLM implements it correctly in PagedAttentionWithALiBi, following the original ALiBi paper.

I tried to hand-calculate Baichuan2's ALiBi mask vs. Baichuan1's in vLLM. Overall, it doesn't look drastically different.

Baichuan2's ALiBi slopes are only slightly different from those in vLLM's Baichuan1 implementation, _get_alibi_slopes() (floating-point precision, perhaps?).

Baichuan2's ALiBi mask generation is consistent with Baichuan1's. Why they modified the implementation is beyond me, though.

vLLM generates the mask like [-3000, -3016, ..., -2.5, -1.1, 0]; the precision around -1 to -2 is high, and those values belong to the closer tokens.

But Baichuan's mask looks like [0, 1, ..., 3000, 3016, 3016]; the precision around 3000 is very low in bf16 (the interval is 16), and the small numbers with higher precision are far away from the current token.
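
A quick way to see this precision gap (a minimal illustration, not code from this PR): bf16 keeps only 8 significand bits, so representable values near 3000 are spaced 16 apart, while values near -2 are spaced roughly 0.01 apart.

import torch

# Distinct fp32 values near 3000 collapse to the same bf16 value (step size 16),
# so a Baichuan-style mask loses resolution exactly where the closest tokens sit.
large = torch.tensor([3000.0, 3004.0, 3007.0], dtype=torch.bfloat16)
print(large.float())  # all three round to the same multiple of 16

# Values near -2 stay distinct; in vLLM's mask the closest tokens sit near 0,
# where bf16 is fine-grained.
small = torch.tensor([-2.0, -1.99, -1.98], dtype=torch.bfloat16)
print(small.float())  # three distinct values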

@@ -15,6 +15,8 @@
"AquilaModel": AquilaForCausalLM,
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
"BaiChuan2ForCausalLM": BaiChuan2ForCausalLM, # baichuan-7b
@ericzhou571 commented Sep 14, 2023

Cool work! However, I still have a small question about the PR.
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/d022d7264467b2c3bc483e7a3a17105dedba50b8/modeling_baichuan.py#L536

https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/d022d7264467b2c3bc483e7a3a17105dedba50b8/config.json#L8
According to Baichuan2's official code, they still call their model BaichuanForCausalLM.
Does that mean that if we directly use a Baichuan2 model downloaded from the HF repo, vLLM will never load the code for Baichuan2?

@garyfanhku (Author)

Thanks for catching this. I pushed a workaround by comparing vocab_size to decide which Baichuan version to call.

In _get_model_architecture():

# baichuan 2 has different vocab size
if ("baichuan" in arch.lower()) and (getattr(config, "vocab_size")
                                     == 125696):
    return Baichuan2ForCausalLM

@JaheimLee commented Sep 15, 2023

When normalizing the fp16 weight, it raises RuntimeError: "clamp_min_scalar_cpu" not implemented for 'Half'. Maybe you should convert it to torch.float32 first.

@garyfanhku (Author)

When normalizing the fp16 weight, it raises RuntimeError: "clamp_min_scalar_cpu" not implemented for 'Half'. Maybe you should convert it to torch.float32 first.

Interesting, maybe we load weights differently? Anyway, it's meant for printing out the norms before and after the norm-head operation for a quick visual inspection; one can simply comment out these two print statements.

I've already done so in my latest commit. You can try and re-fetch this PR.

if name == "lm_head.weight":
      # print(
      #     f"loading lm_head weight, norm: {loaded_weight.norm(2.0, 1, True).clamp_min(1e-12)}, shape: {loaded_weight.size()}"
      # )
      loaded_weight = torch.nn.functional.normalize(loaded_weight)
      # print(
      #     f"after normalization, norm: {loaded_weight.norm(2.0, 1, True).clamp_min(1e-12)}, shape: {loaded_weight.size()}"
      # )

@chenxu2048 (Contributor)

When normalizing the fp16 weight, it raises RuntimeError: "clamp_min_scalar_cpu" not implemented for 'Half'. Maybe you should convert it to torch.float32 first.

I noticed that the function name ends with _cpu. Some functions may not work on CPU with fp16 in PyTorch. Check whether loaded_weight is torch.float32 or on the GPU.

@JaheimLee @garyfanhku
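
A minimal sketch of that suggestion (an assumed fix, not code from this PR): upcast to float32 before normalizing, then cast back, so the Half-on-CPU limitation is never hit.

import torch
import torch.nn.functional as F

def normalize_lm_head(loaded_weight: torch.Tensor) -> torch.Tensor:
    # Normalize in float32 to avoid ops that are unimplemented for Half on CPU,
    # then cast back to the checkpoint's original dtype.
    return F.normalize(loaded_weight.float()).to(loaded_weight.dtype)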

@JaheimLee commented Sep 18, 2023

When normalizing the fp16 weight, it raises RuntimeError: "clamp_min_scalar_cpu" not implemented for 'Half'. Maybe you should convert it to torch.float32 first.

Interesting, maybe we load weights differently? Anyway, it's meant for printing out the norms before and after the norm-head operation for a quick visual inspection; one can simply comment out these two print statements.

I've already done so in my latest commit. You can try and re-fetch this PR.

if name == "lm_head.weight":
      # print(
      #     f"loading lm_head weight, norm: {loaded_weight.norm(2.0, 1, True).clamp_min(1e-12)}, shape: {loaded_weight.size()}"
      # )
      loaded_weight = torch.nn.functional.normalize(loaded_weight)
      # print(
      #     f"after normalization, norm: {loaded_weight.norm(2.0, 1, True).clamp_min(1e-12)}, shape: {loaded_weight.size()}"
      # )

Oh yes, I loaded fine-tuned weights of my own. I trained it using DeepSpeed + LoRA and finally merged the adapter into the original model, so maybe that changed the data type. The official weights don't raise that error.

@zhudianGG

Hi @WoosukKwon,

Could you please review @garyfanhku's PR #1022? The code seems to be without issues. Please review when possible. Thanks!

@garyfanhku (Author)

Hi @WoosukKwon,

Could you please review @garyfanhku's PR #1022? The code seems to be without issues. Please review when possible. Thanks!

I've updated the code to support both Baichuan2-7B and 13B, thanks to the revision proposed in #1092.

Cheers.

Generated text: 'Hello, my name is [your name]. Nice to meet you!'
Prompt: None,
Generated text: 'The current president of the United States is Joe Biden, who was sworn into office on January 20, 2021.'
>>>>>> Baichuan2-13B-Chat 8Bit Demo:


What is the meaning of 8Bit here? I see that the model you loaded above is Baichuan2-13B-Chat.

@garyfanhku (Author)

They're just for comparing vLLM's output with the HF model's output, as a consistency check.


okay

Comment on lines +53 to +56
if getattr(config, "intermediate_size") == 11008:
return BaiChuan2ForCausalLM
elif getattr(config, "intermediate_size") == 13696:
return Baichuan2ForCausalLM


                if getattr(config, "intermediate_size") == 11008:
                    arch = "BaiChuan2ForCausalLM"
                elif getattr(config, "intermediate_size") == 13696:
                    arch = "Baichuan2ForCausalLM"

Is this better? Since you have already added it in _MODEL_REGISTRY.

@garyfanhku (Author)

Seems equivalent to me. Any particular benefits?

@jessiewiswjc commented Oct 12, 2023

They're equivalent, but if you return directly, there is no need to add it to _MODEL_REGISTRY.

@zhouzhou0322 left a comment

lgtm

@zhudianGG

Hi, Gary, perhaps you should consider modifying the specific lines of the original code instead of rewriting it entirely. I believe this would make it easier to merge your code.

@jugglq commented Oct 25, 2023

When normalizing the fp16 weight, it raises RuntimeError: "clamp_min_scalar_cpu" not implemented for 'Half'. Maybe you should convert it to torch.float32 first.

Interesting, maybe we load weights differently? Anyway, it's meant for printing out the norms before and after the norm-head operation for a quick visual inspection; one can simply comment out these two print statements.
I've already done so in my latest commit. You can try and re-fetch this PR.

if name == "lm_head.weight":
      # print(
      #     f"loading lm_head weight, norm: {loaded_weight.norm(2.0, 1, True).clamp_min(1e-12)}, shape: {loaded_weight.size()}"
      # )
      loaded_weight = torch.nn.functional.normalize(loaded_weight)
      # print(
      #     f"after normalization, norm: {loaded_weight.norm(2.0, 1, True).clamp_min(1e-12)}, shape: {loaded_weight.size()}"
      # )

Oh yes, I loaded fine-tuned weights of my own. I trained it using DeepSpeed + LoRA and finally merged the adapter into the original model, so maybe that changed the data type. The official weights don't raise that error.

@JaheimLee @garyfanhku I also encountered this problem; what is the solution? Can you provide me with some ideas? Thank you so much!

@garyfanhku (Author)

When normalizing the fp16 weight, it raises RuntimeError: "clamp_min_scalar_cpu" not implemented for 'Half'. Maybe you should convert it to torch.float32 first.

Interesting, maybe we load weights differently? Anyway, it's meant for printing out the norms before and after the norm-head operation for a quick visual inspection; one can simply comment out these two print statements.
I've already done so in my latest commit. You can try and re-fetch this PR.

if name == "lm_head.weight":
      # print(
      #     f"loading lm_head weight, norm: {loaded_weight.norm(2.0, 1, True).clamp_min(1e-12)}, shape: {loaded_weight.size()}"
      # )
      loaded_weight = torch.nn.functional.normalize(loaded_weight)
      # print(
      #     f"after normalization, norm: {loaded_weight.norm(2.0, 1, True).clamp_min(1e-12)}, shape: {loaded_weight.size()}"
      # )

Oh yes, I loaded fine-tuned weights of my own. I trained it using DeepSpeed + LoRA and finally merged the adapter into the original model, so maybe that changed the data type. The official weights don't raise that error.

@JaheimLee @garyfanhku I also encountered this problem; what is the solution? Can you provide me with some ideas? Thank you so much!

@jugglq Please make sure your local repo is up to date, or check whether the print statements mentioned above were commented out.

@jugglq commented Oct 26, 2023

I also loaded fine-tuned LoRA weights like @JaheimLee. I finally found a solution, and it worked!

loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight = loaded_weight.to('cuda')

When normalizing the fp16 weight, it raises RuntimeError: "clamp_min_scalar_cpu" not implemented for 'Half'. Maybe you should convert it to torch.float32 first.

I noticed that the function name ends with _cpu. Some functions may not work on CPU with fp16 in PyTorch. Check whether loaded_weight is torch.float32 or on the GPU.

@JaheimLee @garyfanhku

@exceedzhang

I found Baichuan2 not working well on transformers==4.34.0!
[screenshot]

@yinjuxin commented Nov 1, 2023

I found Baichuan2 not working well on transformers==4.34.0! [screenshot]

I also have the same problem. Have you solved it?

@zhudianGG

@exceedzhang @yinjuxin Hi, have you checked this solution (#1403)? I faced a similar situation when trying to load Baichuan1. Good luck!

@garyfanhku (Author)

I found Baichuan2 not working well on transformers==4.34.0! [screenshot]

It's an issue related to Baichuan/InternLM. I checked that vLLM bumped its transformers requirement to 4.34 to accommodate Mistral. If that doesn't matter to you, I would suggest downgrading vLLM.

Check this for more info: baichuan-inc/Baichuan2#226

@zhudianGG

@garyfanhku Hi Gary, I noticed that they have made some changes (e.g., they deleted the tensor_parallel dir) since commit ba0bfd4, so you'd better adjust your code for the new version of vLLM so they can merge it, I guess?

@garyfanhku (Author)

@garyfanhku Hi Gary, I noticed that they have made some changes (e.g., they deleted the tensor_parallel dir) since commit ba0bfd4, so you'd better adjust your code for the new version of vLLM so they can merge it, I guess?

It appears that modifying baichuan2.py to follow the changes in

from vllm.model_executor.parallel_utils.parallel_state import (

would suffice. The workload seems manageable. Would you be interested in forking this PR and making it work with the latest vLLM?

@zhudianGG

@garyfanhku Hi Gary, I noticed that they have made some changes (e.g., they deleted the tensor_parallel dir) since commit ba0bfd4, so you'd better adjust your code for the new version of vLLM so they can merge it, I guess?

It appears that modifying baichuan2.py to follow the changes in

from vllm.model_executor.parallel_utils.parallel_state import (

would suffice. The workload seems manageable. Would you be interested in forking this PR and making it work with the latest vLLM?

Sure, I would be glad to help!

@PsLink commented Nov 9, 2023

I tried rebasing the code onto the latest version by modifying baichuan2.py, and it currently works for Baichuan2 models. You can refer to:
https://github.com/PsLink/vllm

@chenxu2048 (Contributor)

The Baichuan2 repo provides a script that can convert Baichuan2 weights into the Baichuan1 format.

import torch
import os
ori_model_dir = 'your Baichuan 2 model directory'
# To avoid overwriting the original model, it's best to save the converted model to another directory before replacing it
new_model_dir = 'your normalized lm_head weight Baichuan 2 model directory'
model = torch.load(os.path.join(ori_model_dir, 'pytorch_model.bin'))
lm_head_w = model['lm_head.weight']
lm_head_w = torch.nn.functional.normalize(lm_head_w)
model['lm_head.weight'] = lm_head_w
torch.save(model, os.path.join(new_model_dir, 'pytorch_model.bin'))

@WoosukKwon (Collaborator)

Closing the PR as the model is supported by vLLM.

@WoosukKwon closed this Mar 13, 2024
@zhouzhou0322

@WoosukKwon Baichuan2 is still not supported. The supported model is Baichuan.
