
Using train_with_template on Mistral ends up in a model with a loop #3055

Open
christobill opened this issue Feb 16, 2024 · 3 comments

@christobill

I use train_with_template.py with mistralai/Mistral-7B-Instruct-v0.2

torchrun --nproc_per_node=2 --master_port=20001 fastchat/train/train_with_template.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
    --data_path data/dummy_conversation.json \
    --bf16 True \
    --output_dir mistral-7b \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'MistralDecoderLayer' \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess False

Then I run the model:

python3 -m fastchat.serve.cli --model-path mistral-7b-0103/ --num-gpus 2 --debug

I get:

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.29it/s]
MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
[INST]: Hello
[/INST]: How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How

{'conv_template': 'mistral', 'prompt': '[INST] Hello [/INST]', 'outputs': 'How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How can I help you today? How', 'speed (token/s)': 58.84}

There might be a problem with the stop token or something?
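A minimal sketch for inspecting the checkpoint's stop-token configuration (assuming the fine-tuned --output_dir from the command above, and that a generation_config.json was saved with it):

from transformers import AutoTokenizer, GenerationConfig

# Assumed local path: the --output_dir used for fine-tuning above.
ckpt = "mistral-7b"

tok = AutoTokenizer.from_pretrained(ckpt)
print("eos_token:", tok.eos_token, "eos_token_id:", tok.eos_token_id)

# Only meaningful if the checkpoint saved a generation_config.json.
gen_cfg = GenerationConfig.from_pretrained(ckpt)
print("generation_config eos_token_id:", gen_cfg.eos_token_id)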

@congchan
Contributor

congchan commented Feb 17, 2024

Hi, I've never seen such a weird issue before; I will test this case when I have time.

I'm not familiar with the Mistral template. If it is the same as Llama's, which uses </s> as the stop token, the script should work; there is a test sample in #3006.
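As a quick check of what the Mistral template actually uses as a stop token, one can render a short conversation with the tokenizer's built-in chat template (a sketch; assumes a transformers version that provides apply_chat_template):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help?"},
]
# Assistant turns should end with the EOS token ("</s>"), same as Llama.
print(tok.apply_chat_template(messages, tokenize=False))
print("eos_token:", tok.eos_token)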

@congchan
Contributor

Hi, I think this is model-specific behavior. Some models need to be explicitly trained to generate the stop tokens.

  • I tested with Yi-34B, and the model works fine.
  • When testing with Qwen1.5, the model failed to generate stop tokens, similar to your Mistral model.
    • After explicitly training Qwen1.5 to generate the stop token, it stops when needed.

If you would like to try this on your Mistral model, change the line

to

                if i < len(turns) - 1:
                    turn = turn + user_turn_separator
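For intuition, here is a minimal, self-contained sketch (not the actual FastChat preprocessing; the turn strings and the separator value are illustrative) of how appending the user-turn separator puts the stop token into the supervised targets:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
user_turn_separator = tok.eos_token  # "</s>" for Llama/Mistral-style models

# Fake turns standing in for what the conversation template would produce.
turns = [
    "[INST] Hello [/INST] Hi, how can I help?",
    "[INST] Bye [/INST] Goodbye!",
]

for i, turn in enumerate(turns):
    # Proposed change: append the separator so the loss also covers the stop
    # token and the model learns to emit it.
    if i < len(turns) - 1:
        turn = turn + user_turn_separator
    ids = tok(turn, add_special_tokens=False).input_ids
    print(repr(turn), "ends with EOS:", ids[-1] == tok.eos_token_id)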

@christobill
Author

christobill commented Feb 18, 2024

Thanks @congchan for looking into this.

I have the same kind of issue with vicuna-7b, but it is a bit more random:

torchrun --nproc_per_node=2 --master_port=20001 fastchat/train/train_with_template.py \
    --model_name_or_path lmsys/vicuna-7b-v1.5 \
    --data_path data/dummy_conversation.json \
    --bf16 True \
    --output_dir vicuna-7b \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True

Then, after several runs (the last one loops):

# python3 -m fastchat.serve.cli --model-path vicuna-7b/ --num-gpus 2 --debug
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.46it/s]
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
USER: Bonjour
ASSISTANT: Bonjour! How can I help you today?

{'conv_template': 'vicuna_v1.1', 'prompt': "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Bonjour ASSISTANT:", 'outputs': 'Bonjour! How can I help you today?', 'speed (token/s)': 14.79}

USER: ^Cexit...

# python3 -m fastchat.serve.cli --model-path vicuna-7b/ --num-gpus 2 --debug  --conv-system-msg "You are Vicuna"
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:03<00:00,  1.97it/s]
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
USER: Hello
ASSISTANT: How can I help you today? If you have any more questions in the future, don't hesitate to ask. hopefully one of my team members will be able to help you out. Goodbye!

{'conv_template': 'vicuna_v1.1', 'prompt': 'You are Vicuna USER: Hello ASSISTANT:', 'outputs': "How can I help you today? If you have any more questions in the future, don't hesitate to ask. hopefully one of my team members will be able to help you out. Goodbye!", 'speed (token/s)': 36.56}

USER: !!regen
regenerating last message...
USER: Hello
ASSISTANT: How can I help you today? If you have any more questions in the future, don't hesitate to ask. grows.

{'conv_template': 'vicuna_v1.1', 'prompt': 'You are Vicuna USER: Hello ASSISTANT:', 'outputs': "How can I help you today? If you have any more questions in the future, don't hesitate to ask. grows.", 'speed (token/s)': 58.82}

USER: Hello
ASSISTANT: How can I help you today? If you have any more questions in the future, don't hesitate to ask.

{'conv_template': 'vicuna_v1.1', 'prompt': "You are Vicuna USER: Hello ASSISTANT: How can I help you today? If you have any more questions in the future, don't hesitate to ask. grows.</s>USER: Hello ASSISTANT:", 'outputs': "How can I help you today? If you have any more questions in the future, don't hesitate to ask.", 'speed (token/s)': 57.91}

USER: How are you
ASSISTANT: My apologies! If you have any more questions in the future, don't hesitate to ask. I'm here to help. grows.

{'conv_template': 'vicuna_v1.1', 'prompt': "You are Vicuna USER: Hello ASSISTANT: How can I help you today? If you have any more questions in the future, don't hesitate to ask. grows.</s>USER: Hello ASSISTANT: How can I help you today? If you have any more questions in the future, don't hesitate to ask.</s>USER: How are you ASSISTANT:", 'outputs': "My apologies! If you have any more questions in the future, don't hesitate to ask. I'm here to help. grows.", 'speed (token/s)': 59.24}

USER: ^Cexit...

# python3 -m fastchat.serve.cli --model-path vicuna-7b/ --num-gpus 2 --debug  --conv-system-msg "You are Vicuna a chatbot designed to help people"
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.56it/s]
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
USER: Hello
ASSISTANT: Hello! How can I help you today? If you have any more questions in the future, don't hesitate to ask. A good percentage of people don't have any more questions after a session with a chatbot like me because they're afraid to ask. If you have any more questions, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future

{'conv_template': 'vicuna_v1.1', 'prompt': 'You are Vicuna a chatbot designed to help people USER: Hello ASSISTANT:', 'outputs': "Hello! How can I help you today? If you have any more questions in the future, don't hesitate to ask. A good percentage of people don't have any more questions after a session with a chatbot like me because they're afraid to ask. If you have any more questions, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future, don't hesitate to ask. If you have any more questions in the future", 'speed (token/s)': 55.03}

I will try the solution you suggested! 🙏
