
[Misc] Refactor linear layer weight loading; introduce BasevLLMParameter and weight_loader_v2 #5874

Merged
merged 22 commits into vllm-project:main on Aug 7, 2024

Conversation

@dsikka (Contributor) commented Jun 26, 2024

Summary

  • Introduces a series of new parameter classes to handle the different weight-loading cases for parameters loaded into linear layers:
  1. BasevLLMParameter
  2. ModelWeightParameter
  3. GroupQuantScaleParameter
  4. ChannelQuantScaleParameter
  5. PerTensorScaleParameter
  6. PackedvLLMParameter
  • Each of these parameter classes handles the weight-loading logic specific to the different LinearBase classes (a rough sketch of the idea follows below)
  • Significantly cleans up the weight_loader method in each of the LinearBase classes
  • For now, changes are only made to the compressed-tensors quantization configs by adding a weight_loader_v2 method to each of the LinearBase classes; all other quantization configurations continue to use the original weight loader and are outside the scope of this PR

FOLLOW UP:

  • Convert other integrations to use this framework once the design is approved
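
For context, here is a minimal sketch of the idea behind the new parameter hierarchy. The class names mirror the ones listed above, but the constructor and method signatures are simplified assumptions for illustration, not the PR's exact implementation:

```python
import torch
from torch.nn import Parameter


class BasevLLMParameter(Parameter):
    """Sketch: a Parameter that carries its own weight-loading behavior."""

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader, **kwargs):
        # weight_loader is the callback the model loader invokes for this param.
        self.weight_loader = weight_loader

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        # Default: the parameter is replicated, so copy the tensor verbatim.
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)


class ChannelQuantScaleParameter(BasevLLMParameter):
    """Sketch: per-output-channel scale, sharded only along the output dim."""

    def __init__(self, output_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor,
                                    tp_rank: int):
        # Keep only the channel slice owned by this tensor-parallel rank.
        shard_size = self.data.shape[self.output_dim]
        shard = loaded_weight.narrow(self.output_dim,
                                     tp_rank * shard_size, shard_size)
        self.data.copy_(shard)
```

Under this sketch, a layer's create_weights would build, for example, ChannelQuantScaleParameter(data=torch.empty(output_size), output_dim=0, weight_loader=weight_loader), and weight_loader_v2 would delegate to the parameter's load_*_weight method instead of branching on flags.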

@robertgshaw2-neuralmagic (Collaborator) left a comment

moving comments down

@robertgshaw2-neuralmagic (Collaborator) commented Jul 1, 2024

This is much better.

I still think we have too much logic tied between linear.py and vLLMParameter, as we are still branching on vLLMParameter booleans inside linear.py.

I think the following two changes would make a better interface.

Remove use_col_loading, use_row_loading, use_col_shard_split

We currently have a series of booleans in vLLMParameter (use_col_loading, use_row_loading, and use_col_shard_split). We then do if/else on these inside linear.py. This is still pretty confusing to follow.

I think that instead of using these booleans, we should just create 4 separate classes

  • WeightParameter
  • GroupedScaleParameter
  • ChannelwiseScaleParameter
  • PerTensorScaleParameter

These have the following:

  • The weights have use_col_loading=True, use_row_loading=True
  • The grouped scales have use_col_loading=True, use_row_loading=True
  • The channelwise scales have use_col_loading=True, use_row_loading=False
  • The per-tensor scales have use_col_shard_split=True

This will make it explicit which case we are in and why, rather than implying it via the booleans we have now.
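
For illustration only (the base class and its flag defaults here are a hypothetical stand-in), the mapping above could be expressed as class-level defaults so linear.py dispatches on type rather than flag-checking:

```python
class vLLMParameter:
    # Hypothetical stand-in for the existing base class and its flags.
    use_col_loading = False
    use_row_loading = False
    use_col_shard_split = False


class WeightParameter(vLLMParameter):
    use_col_loading = True
    use_row_loading = True


class GroupedScaleParameter(vLLMParameter):
    use_col_loading = True
    use_row_loading = True


class ChannelwiseScaleParameter(vLLMParameter):
    use_col_loading = True
    use_row_loading = False


class PerTensorScaleParameter(vLLMParameter):
    use_col_shard_split = True
```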

Move _default_loading into vLLMParameter classes

Function signatures would look like:

class vLLMParameter:
    def load_merged_column_parallel_linear(self, layer: MergedColumnParallelLinear, loaded_weight: torch.Tensor):
        pass

    def load_row_parallel_linear(self, layer: RowParallelLinear, loaded_weight: torch.Tensor):
        pass
    ...

Then, with this update, in weight_loader_v2, instead of:

if param.use_column_loading:
    param_data, loaded_weight = self._default_loading(
        param=param,
        param_data=param_data,
        loaded_weight=loaded_weight,
        loaded_shard_id=loaded_shard_id)
elif param.use_metadata_loading:  # What case is this?
    shard_size = loaded_weight.shape[0]
    shard_index = ["q", "k", "v"].index(loaded_shard_id)
    param_data = param_data.narrow(0, shard_index * shard_size,
                                   shard_size)
elif param.use_col_shard_split:
    param_data, loaded_weight = param.col_shard_splitter(
        param_data=param_data,
        loaded_weight=loaded_weight,
        shard_id=loaded_shard_id)

It would just be:

param_data, loaded_weight = param.load_merged_column_parallel_linear(
    layer=self,
    loaded_weight=loaded_weight)

This would better encapsulate the logic and simplify linear.py.
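
A rough sketch of what that could look like on the layer side (the class here is an illustrative stub and the copy-back step is an assumption, not the final API):

```python
import torch


class MergedColumnParallelLinear:
    """Illustrative stub: with the proposal above, the layer no longer
    branches on parameter flags."""

    def weight_loader_v2(self, param, loaded_weight: torch.Tensor) -> None:
        # The parameter decides how to narrow/shard itself for this layer
        # type; the layer only performs the final copy.
        param_data, loaded_weight = param.load_merged_column_parallel_linear(
            layer=self, loaded_weight=loaded_weight)
        param_data.copy_(loaded_weight)
```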

@dsikka changed the title from "Refactor weight loading" to "[Misc] Refactor linear layer weight loading" on Jul 2, 2024
@dsikka changed the title from "[Misc] Refactor linear layer weight loading" to "[Misc] Refactor linear layer weight loading; introduce BasevLLMParameter and weight_loader_v2" on Jul 2, 2024
@dsikka marked this pull request as ready for review on July 2, 2024 16:31
@robertgshaw2-neuralmagic (Collaborator) commented:

Couple nits but LGTM

@robertgshaw2-neuralmagic removed the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Jul 30, 2024
The github-actions bot added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Aug 1, 2024
@comaniac (Collaborator) left a comment

Sorry for the late review. Overall LGTM, so approving to unblock this PR and the follow-up tasks.


# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
# TODO: update create_xxx_parameter functions to return
Collaborator:

Is this still a TODO?

@dsikka (Contributor, Author) replied Aug 5, 2024:

Yes. We're not using the create_xxx_parameter methods here because they are also used in places outside of compressed_tensors (e.g. fp8). As a follow-up, once we've updated the other quantization methods to use these new parameters, we can update the create_xxx_parameter functions to return the vLLMParameters; they currently return torch.nn.Parameter objects.
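
As a sketch of that follow-up (the helper name, tensor shape, and constructor keywords here are assumptions for illustration, not code from this PR), such a helper could eventually construct a vLLM parameter directly:

```python
import torch

# The parameter classes introduced by this PR live in
# vllm/model_executor/parameter.py.
from vllm.model_executor.parameter import ChannelQuantScaleParameter


def create_channel_scale_parameter(output_size: int, weight_loader):
    """Hypothetical replacement for a create_xxx_parameter helper that
    returns a vLLM parameter instead of a bare torch.nn.Parameter."""
    return ChannelQuantScaleParameter(
        data=torch.empty(output_size, dtype=torch.float32),
        output_dim=0,
        weight_loader=weight_loader)
```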

- channelwise = (self.group_size == -1)
- group_size = input_size if channelwise else self.group_size
+ channelwise = self.group_size == -1
+ group_size = self.group_size if self.group_size != -1 else input_size
Collaborator:

why change this?

@dsikka (Contributor, Author) replied Aug 5, 2024:

The second condition is just clearer about what group_size is and why.

@@ -230,14 +225,16 @@ def _get_scheme_from_parts(
group_size=weight_quant.group_size)

# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
Collaborator:

Is this still a TODO?

@dsikka (Contributor, Author) replied:

Yes, a general follow-up on the state of these conditions.

if is_activation_quantization_format(self.quant_format):
    if self._is_fp8_w8a8(weight_quant, input_quant):
        is_fp8_w8a8_supported = self._check_scheme_supported(
            CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
        if is_fp8_w8a8_supported:
            return CompressedTensorsW8A8Fp8(
                strategy=weight_quant.strategy,
-               is_static_input_scheme=(not input_quant.dynamic))
+               is_static_input_scheme=(input_quant
Collaborator:

Why is this changing? Won't we always have input_quant if is_activation_quantization_format?

@dsikka (Contributor, Author) replied:

Just an extra check that the activation config details aren't None and were parsed correctly.
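
In generic terms (this is an illustration of the guard pattern, not the PR's exact expression), the check simply tolerates a missing activation-quant config:

```python
from types import SimpleNamespace

# If the activation-quant section failed to parse, input_quant may be None;
# the `and` short-circuits before .dynamic is accessed.
input_quant = SimpleNamespace(dynamic=False)  # or None when parsing failed
is_static_input_scheme = bool(input_quant and not input_quant.dynamic)
print(is_static_input_scheme)  # True for a static (non-dynamic) input scheme
```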

@robertgshaw2-neuralmagic (Collaborator) commented:

Make sure to unblock the multi-gpu A100 model correctness tests. Nice job!

auto-merge was automatically disabled August 5, 2024 22:32

Head branch was pushed to by a user without write access

@simon-mo merged commit 0f7052b into vllm-project:main on Aug 7, 2024
64 of 68 checks passed
sfc-gh-mkeralapura pushed a commit to sfc-gh-mkeralapura/vllm that referenced this pull request Aug 12, 2024
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Aug 22, 2024