Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc]Add customized information for models #4132

Merged
merged 6 commits into from
May 1, 2024

Conversation

jeejeelee
Copy link
Contributor

@jeejeelee jeejeelee commented Apr 17, 2024

When I debug the VLLM, the model's print output always bothers me because it lacks details, as shown below:

LlavaForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=1024, out_features=4096, bias=True)
              (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            )
            (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (multi_modal_projector): LlavaMultiModalProjector(
    (linear_1): Linear(in_features=1024, out_features=4096, bias=True)
    (act): GELU(approximate='none')
    (linear_2): Linear(in_features=4096, out_features=4096, bias=True)
  )
  (language_model): LlamaModel(
    (embed_tokens): VocabParallelEmbedding()
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (qkv_proj): QKVParallelLinear()
          (o_proj): RowParallelLinear()
          (rotary_emb): RotaryEmbedding()
          (attn): Attention()
        )
        (mlp): LlamaMLP(
          (gate_up_proj): MergedColumnParallelLinear()
          (down_proj): RowParallelLinear()
          (act_fn): SiluAndMul()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (lm_head): ParallelLMHead()
  (logits_processor): LogitsProcessor()
  (sampler): Sampler()
)

By leveraging extra_repr, we can add more details to the model, and can achieve the following print output:

LlavaForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=1024, out_features=4096, bias=True)
              (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            )
            (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (multi_modal_projector): LlavaMultiModalProjector(
    (linear_1): Linear(in_features=1024, out_features=4096, bias=True)
    (act): GELU(approximate='none')
    (linear_2): Linear(in_features=4096, out_features=4096, bias=True)
  )
  (language_model): LlamaModel(
    (embed_tokens): VocabParallelEmbedding(num_embeddings=32064, embedding_dim=4096, org_vocab_size=32064, num_embeddings_padded=32064, tp_size=1)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (qkv_proj): QKVParallelLinear(in_features=4096, output_features=12288, bias=False, tp_size=1, gather_output=False)
          (o_proj): RowParallelLinear(input_features=4096, output_features=4096, bias=False, tp_size=1, reduce_results=True)
          (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=4096, base=10000.0, is_neox_style=True)
          (attn): Attention(head_size=128, num_heads=32, num_kv_heads=32, scale=0.08838834764831845)
        )
        (mlp): LlamaMLP(
          (gate_up_proj): MergedColumnParallelLinear(in_features=4096, output_features=22016, bias=False, tp_size=1, gather_output=False)
          (down_proj): RowParallelLinear(input_features=11008, output_features=4096, bias=False, tp_size=1, reduce_results=True)
          (act_fn): SiluAndMul()
        )
        (input_layernorm): RMSNorm(hidden_size=4096, eps=1e-05)
        (post_attention_layernorm): RMSNorm(hidden_size=4096, eps=1e-05)
      )
    )
    (norm): RMSNorm(hidden_size=4096, eps=1e-05)
  )
  (lm_head): ParallelLMHead(num_embeddings=32064, embedding_dim=4096, org_vocab_size=32064, num_embeddings_padded=32064, tp_size=1)
  (logits_processor): LogitsProcessor(vocab_size=32064, org_vocab_size=32064, scale=1.0, logits_as_input=False)
  (sampler): Sampler()

@jeejeelee
Copy link
Contributor Author

@zhuohan123 could I trouble you to review this PR, I am not sure whether this feature is useful. If not, I will close this PR.

Copy link
Collaborator

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would be pretty useful!

vllm/model_executor/layers/logits_processor.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/rotary_embedding.py Outdated Show resolved Hide resolved
@jeejeelee jeejeelee requested a review from mgoin April 18, 2024 02:40
Copy link
Collaborator

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, thank you

@jeejeelee
Copy link
Contributor Author

@mgoin It seems that we also need a collaborator to review this PR, is that right?

@jeejeelee
Copy link
Contributor Author

@simon-mo could I trouble you to take a look at this PR, please?

@jeejeelee jeejeelee requested a review from mgoin April 20, 2024 15:46
@jeejeelee
Copy link
Contributor Author

@zhuohan123 @simon-mo Could you please take a look at this PR? If it's not useful, I'll close it. I'm hoping to get some feedback. :bowtie:

@jeejeelee jeejeelee mentioned this pull request May 1, 2024
@youkaichao
Copy link
Member

I think this is a good direction in general. You should add comment to point to the doc https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.extra_repr , this feature is not well-known.

In addition, please merge main into this branch to pass the CI.

@simon-mo simon-mo merged commit d6f4bd7 into vllm-project:main May 1, 2024
17 of 19 checks passed
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
@jeejeelee jeejeelee deleted the model-print branch May 6, 2024 16:08
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 7, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants