This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[8/x] make single linear profiling script work with Float8 scaling type #299

Closed
wants to merge 1 commit into from
51 changes: 47 additions & 4 deletions benchmarks/bench_linear_float8.py
@@ -14,8 +14,10 @@

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import (
get_float8_linear,
linear_requires_sync,
LinearType,
sync_float8_amax_and_scale_history,
)
@@ -68,6 +70,7 @@ class Experiment:
compiled: bool
use_fast_accum: bool
linear_type: str
scaling_repr: str

# 3 Times since we are calculating forward backward
@property
@@ -96,10 +99,17 @@ def main(
fast_accum_filter: Optional[bool] = None,
shape_name_filter: Optional[str] = None,
linear_type_filter: Optional[str] = None,
scaling_type_x: str = "delayed",
scaling_type_w: str = "delayed",
scaling_type_dL_dY: str = "delayed",
):
device = "cuda"
print(f"Compile is set to | {compile}")

scaling_type_x = TensorScalingType(scaling_type_x)
scaling_type_w = TensorScalingType(scaling_type_w)
scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
name_to_shapes_70b = {
@@ -134,9 +144,24 @@ def main(
LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
)

linear_float8 = get_float8_linear(
linear_type_enum, copy.deepcopy(linear_ref), emulate=False
)
if linear_type == "delayed":
linear_float8 = get_float8_linear(
linear_type_enum,
copy.deepcopy(linear_ref),
emulate=False,
scaling_type_x=scaling_type_x,
scaling_type_w=scaling_type_w,
scaling_type_dL_dY=scaling_type_dL_dY,
)
scaling_repr = linear_float8.scaling_repr()
else:
linear_float8 = get_float8_linear(
linear_type_enum,
copy.deepcopy(linear_ref),
emulate=False,
)
scaling_repr = None

if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
else:
@@ -150,7 +175,10 @@
if linear_type_enum == LinearType.DELAYED:

def float8_forw_backward():
sync_float8_amax_and_scale_history(linear_float8)
if linear_requires_sync(
linear_type_enum, scaling_type_x, scaling_type_w, scaling_type_dL_dY
):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

else:
@@ -197,6 +225,7 @@ def wrapper(*args, **kwargs):
compile,
use_fast_accum=fast_accum,
linear_type=linear_type,
scaling_repr=scaling_repr,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -209,6 +238,7 @@ def wrapper(*args, **kwargs):
"K",
"N",
"linear_type",
"scaling_repr",
"ref_dtype",
"compiled",
"use_fast_accum",
@@ -228,6 +258,7 @@ def wrapper(*args, **kwargs):
experiment.shape[1],
experiment.shape[2],
experiment.linear_type,
experiment.scaling_repr,
experiment.dtype,
experiment.compiled,
experiment.use_fast_accum,
@@ -257,6 +288,7 @@ def wrapper(*args, **kwargs):
"name",
"shape",
"linear_type",
"scaling_repr",
"compiled",
"use_fast_accum",
"ref_time_sec",
@@ -280,15 +312,26 @@ def invoke_main() -> None:
parser.add_argument("--fast_accum_filter", type=bool, required=False)
parser.add_argument("--shape_name_filter", type=str, required=False)
parser.add_argument("--linear_type_filter", type=str, required=False)
parser.add_argument("--scaling_type_x", type=str, required=False)
parser.add_argument("--scaling_type_w", type=str, required=False)
parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
args = parser.parse_args()
output_path = Path(args.output_path) if args.output_path is not None else None
kwargs = {}
if args.scaling_type_x is not None:
kwargs["scaling_type_x"] = args.scaling_type_x
if args.scaling_type_w is not None:
kwargs["scaling_type_w"] = args.scaling_type_w
if args.scaling_type_dL_dY is not None:
kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY
main(
output_path,
args.compile,
args.n_limit,
args.fast_accum_filter,
args.shape_name_filter,
args.linear_type_filter,
**kwargs,
)


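For context, a minimal sketch of exercising the new scaling-type knobs from Python, assuming the script is importable as benchmarks.bench_linear_float8, that the positional arguments follow the order used in invoke_main() above, and that "dynamic" is a valid TensorScalingType value alongside the default "delayed":

# Minimal sketch (not part of this PR): drive the updated benchmark from Python.
# Assumes "dynamic" is an accepted TensorScalingType string and that output_path,
# compile, and n_limit keep the meanings implied by invoke_main() above.
from pathlib import Path

from benchmarks.bench_linear_float8 import main

main(
    Path("bench_delayed_vs_dynamic.csv"),  # output_path (hypothetical file name)
    True,                                  # compile
    None,                                  # n_limit: benchmark every shape
    None,                                  # fast_accum_filter
    None,                                  # shape_name_filter
    "delayed",                             # linear_type_filter
    scaling_type_x="dynamic",
    scaling_type_w="delayed",
    scaling_type_dL_dY="dynamic",
)

# Roughly equivalent CLI, assuming --output_path and --compile are defined in the
# collapsed portion of the argument parser:
#   python benchmarks/bench_linear_float8.py --output_path bench.csv --compile True \
#       --linear_type_filter delayed --scaling_type_x dynamic \
#       --scaling_type_w delayed --scaling_type_dL_dY dynamic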
11 changes: 5 additions & 6 deletions float8_experimental/float8_linear.py
@@ -400,14 +400,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.float8_post_forward()
return y

def extra_repr(self):
# example: in_features=32, out_features=16, bias=True
s = super().extra_repr()
def scaling_repr(self):
# add scaling settings without using too many characters
scaling = f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
# example: "x:del,w:del,dldy:dyn"
return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"

s = f'{s}, scaling="{scaling}"'
# example: in_features=32, out_features=16, bias=True, scaling="x:del,w:del,dldy:dyn"
def extra_repr(self):
s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
return s

@classmethod
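Because this hunk interleaves the removed extra_repr body with the newly added lines, here is a hedged reconstruction of how the two methods read after the change (the bodies are taken from the added lines above; indentation and placement inside the Float8Linear class are assumed):

# Sketch of the post-change methods, reconstructed from the hunk above; the
# surrounding class context and indentation are assumed, not copied verbatim.
def scaling_repr(self):
    # add scaling settings without using too many characters
    # example: "x:del,w:del,dldy:dyn"
    return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"

# example: in_features=32, out_features=16, bias=True, scaling="x:del,w:del,dldy:dyn"
def extra_repr(self):
    s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
    return s

With this split, extra_repr() appends the scaling="x:del,w:del,dldy:dyn"-style suffix to the usual in_features/out_features/bias string, and the benchmark above can record the same string in its new scaling_repr column without duplicating the formatting logic.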