This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[ Misc ] Rs/compressed tensors cleanup (vllm-project#5432)
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
3 people committed Jun 23, 2024
1 parent 34467ee commit deee747
Showing 4 changed files with 21 additions and 36 deletions.
File 1 of 4:

@@ -26,7 +26,7 @@ def get_scaled_act_names(self) -> List[str]:
         return []

     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        return [torch.float16]
+        return [torch.float16, torch.bfloat16]

     # Need to figure it out
     def get_min_capability(self) -> int:
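Note for context: this hunk widens the activation dtypes the compressed-tensors config reports as supported. Below is a minimal, hypothetical sketch (not vLLM code) of the kind of check such a list feeds into; the helper name assert_dtype_supported is made up.

from typing import List

import torch


def assert_dtype_supported(model_dtype: torch.dtype,
                           supported: List[torch.dtype]) -> None:
    # Hypothetical helper: reject dtypes the quantization method cannot handle.
    if model_dtype not in supported:
        raise ValueError(f"{model_dtype} not in supported dtypes {supported}")


# With this change, bfloat16 models pass the check as well as float16 ones.
assert_dtype_supported(torch.bfloat16, [torch.float16, torch.bfloat16])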
File 2 of 4:

@@ -64,10 +64,9 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                 "input_dim": 1,
                 "output_dim": 0,
                 "packed_dim": 1,
-                "pack_factor": pack_factor
+                "pack_factor": pack_factor,
+                "weight_loader": weight_loader
             })
-        set_weight_attrs(weight, {"weight_loader": weight_loader})
-
         layer.register_parameter("weight_packed", weight)

         weight_scale = Parameter(
@@ -79,11 +78,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
             requires_grad=False,
         )

-        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
-        set_weight_attrs(weight_scale, {
-            "input_dim": weight_scale_dim,
-            "output_dim": 0
-        })
+        set_weight_attrs(
+            weight_scale, {
+                "weight_loader": weight_loader,
+                "input_dim": weight_scale_dim,
+                "output_dim": 0
+            })
         layer.register_parameter("weight_scale", weight_scale)

         # A 2D array defining the original shape of the weights
@@ -92,7 +92,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                  requires_grad=False)

         layer.register_parameter("weight_shape", weight_shape)
-        set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
+        set_weight_attrs(weight_shape, {
+            "weight_loader": weight_loader,
+            "ignore_warning": True,
+        })

         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
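The recurring change in this file is folding several set_weight_attrs calls into a single one so all loader metadata lands on the parameter at once. Below is a rough standalone sketch of that pattern; the local set_weight_attrs is a simplified stand-in for vLLM's helper, and the shapes, pack_factor, and lambda loader are made-up placeholders.

from typing import Any, Dict

import torch
from torch.nn import Parameter


def set_weight_attrs(weight: torch.Tensor, attrs: Dict[str, Any]) -> None:
    # Simplified stand-in: attach each key/value pair as an attribute on the
    # parameter so the weight loader can read it later.
    for key, value in attrs.items():
        setattr(weight, key, value)


pack_factor = 8  # e.g. 4-bit values packed into int32
weight = Parameter(torch.empty(256, 512 // pack_factor, dtype=torch.int32),
                   requires_grad=False)

# One call carries all of the loader metadata, matching the pattern above.
set_weight_attrs(
    weight, {
        "input_dim": 1,
        "output_dim": 0,
        "packed_dim": 1,
        "pack_factor": pack_factor,
        "weight_loader": lambda param, loaded: param.data.copy_(loaded),
    })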
File 3 of 4:

@@ -48,9 +48,6 @@ def create_weights(self, layer: torch.nn.Module,
         weight_scale_dim = sum(
             output_partition_sizes) if is_tensor_partitioned else 1

-        weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
-                                      requires_grad=False)
-
         weight_scale = Parameter(torch.empty(weight_scale_dim,
                                              dtype=torch.float32),
                                  requires_grad=False)
@@ -61,21 +58,22 @@ def create_weights(self, layer: torch.nn.Module,
                            requires_grad=False)

         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        set_weight_attrs(weight, {"weight_loader": weight_loader})
-        set_weight_attrs(weight, {"logical_widths": output_partition_sizes})
+        set_weight_attrs(
+            weight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "weight_loader": weight_loader,
+                "logical_widths": output_partition_sizes
+            })

         layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
         set_weight_attrs(
             weight_scale, {
+                "weight_loader": weight_loader,
                 "shard_splitter": self.scales_shard_splitter,
                 "logical_widths": output_partition_sizes
             })

-        layer.register_parameter("weight_zero_point", weight_zero_point)
-        set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
-
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
         weight_scale = layer.weight_scale
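For context on the removed weight_zero_point parameter: the int8 scheme here is symmetric, so dequantization needs only a scale, which is presumably why the unused zero-point tensor can be dropped. A plain-PyTorch illustration (not vLLM's kernel path) of a symmetric per-tensor round trip:

import torch


def quantize_symmetric_int8(weight: torch.Tensor):
    # Per-tensor symmetric quantization: a single scale, no zero point.
    scale = weight.abs().max().clamp(min=1e-8) / 127.0
    q = torch.round(weight / scale).clamp(-128, 127).to(torch.int8)
    return q, scale


def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Zero maps to zero, which is why no zero-point tensor is required.
    return q.to(torch.float32) * scale


w = torch.randn(16, 32)
q, scale = quantize_symmetric_int8(w)
w_hat = dequantize_int8(q, scale)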
File 4 of 4:

@@ -39,22 +39,16 @@ def create_weights(self, layer: torch.nn.Module,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):

-        # TODO: remove zero_point parameters once the configs given remove them
-
         is_tensor_partitioned = len(output_partition_sizes) != 1
         weight_scale_dim = sum(
             output_partition_sizes) if is_tensor_partitioned else 1

         input_scale = Parameter(torch.empty(1, dtype=torch.float32),
                                 requires_grad=False)
-        input_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
-                                     requires_grad=False)

         weight_scale = Parameter(torch.empty(weight_scale_dim,
                                              dtype=torch.float32),
                                  requires_grad=False)
-        weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
-                                      requires_grad=False)

         weight = Parameter(torch.empty(sum(output_partition_sizes),
                                        input_size_per_partition,
@@ -72,11 +66,6 @@ def create_weights(self, layer: torch.nn.Module,
             "weight_loader": weight_loader,
             "ignore_warning": True,
         })
-        layer.register_parameter("input_zero_point", input_zero_point)
-        set_weight_attrs(input_zero_point, {
-            "weight_loader": weight_loader,
-            "ignore_warning": True,
-        })
         layer.register_parameter("weight_scale", weight_scale)
         set_weight_attrs(
             weight_scale, {
@@ -85,11 +74,6 @@ def create_weights(self, layer: torch.nn.Module,
                 "logical_widths": output_partition_sizes,
                 "ignore_warning": True,
             })
-        layer.register_parameter("weight_zero_point", weight_zero_point)
-        set_weight_attrs(weight_zero_point, {
-            "weight_loader": weight_loader,
-            "ignore_warning": True
-        })

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
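For context on the parameters this scheme keeps (input_scale and weight_scale) versus the zero points it drops: with symmetric per-tensor scales, a W8A8 matmul can be emulated as below. This is a reference-style sketch in plain PyTorch, not vLLM's quantized kernel; all shapes and scale values are made up.

import torch


def w8a8_static_gemm_reference(x: torch.Tensor, weight_q: torch.Tensor,
                               weight_scale: torch.Tensor,
                               input_scale: torch.Tensor) -> torch.Tensor:
    # Quantize activations with the calibrated per-tensor input_scale.
    x_q = torch.round(x / input_scale).clamp(-128, 127).to(torch.int8)
    # Emulate the integer matmul in float32 (exact for these small ranges),
    # then dequantize with both scales.
    acc = x_q.to(torch.float32) @ weight_q.to(torch.float32).t()
    return acc * input_scale * weight_scale


# Toy shapes; real scales would be calibrated and loaded from the checkpoint.
x = torch.randn(4, 32)
weight_q = torch.randint(-128, 128, (64, 32), dtype=torch.int8)
out = w8a8_static_gemm_reference(x, weight_q,
                                 weight_scale=torch.tensor(0.02),
                                 input_scale=torch.tensor(0.05))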
