What are the expected inference steps after I apply torchao in training? #1132

Hello, I have integrated torchao into my training, but I don't think it's 100% clear what inference should look like.
Should I use the converted FP8 linear layers to do inference? Is delayed scaling supposed to work in inference?
Or should I use the original linear layers to do inference?
Thanks a lot in advance if you can help clarify!

Comments
@goldhuang I agree we need some more documentation here. Our previous inference solution worked correctly here, but today it does not. For example, the following will error:

```python
import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

print(m)

# apply post-training quantization to the float8-trained model
from torchao.quantization.quant_api import float8_weight_only, quantize_

quantize_(m, float8_weight_only())
print(m)

with torch.no_grad():
    y = m(x)
print(y.dtype)
print(y)
```
@drisspg Thanks! If I only care about quality, which one is better? (We have other solutions for inference, like TRT. I only want to check the quality of the results after training with torchao.)
Ohh, that is an interesting question. You only care about quality and you don't care about weight size? If so, you are likely looking for a different solution; let me hack something up for you.
Below shows how you can convert the `Float8Linear` layers back to `nn.Linear` in the original high precision:

```python
import torch
import torch.nn as nn
from typing import Callable, Optional

from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(32, 64, bias=False),
    nn.Linear(64, 64, bias=False),
).bfloat16().cuda()
x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
target = torch.ones(16, 64, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.5)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# toy training loop
for _ in range(1000):
    optimizer.zero_grad()
    y = m(x)
    loss = torch.nn.functional.mse_loss(y, target)
    loss.backward()
    optimizer.step()

print(m)

### FINISHED TRAINING ###

from torchao.float8.float8_linear import Float8Linear


def swap_linear_layers(
    module: nn.Module,
    target_module: nn.Module,
    swap_func: Callable[[nn.Linear], nn.Linear],
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Generic function to swap linear layers in a module with a new type of linear layer.

    Note:
        If applied to a root-level nn.Linear, the module will not be modified in place
        and is returned instead.

    Args:
        module: Module to modify.
        target_module: Type of module to replace.
        swap_func: Function that accepts a linear layer and returns a new type of linear layer.
        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
            pass the filter function will be swapped. The inputs to the
            filter function are the module instance and the FQN.

    Returns:
        nn.Module: The modified module with swapped linear layers.
    """
    if isinstance(module, target_module) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root {target_module} with children: {module}"
            )
        return swap_func(module)

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""
        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"
            post_order_traversal(child_module, new_fqn, module)
        if isinstance(module, target_module) and (
            module_filter_fn is None or module_filter_fn(module, cur_fqn)
        ):
            assert (
                parent_module is not None
            ), f"{target_module} root module should return early: {module}"
            new_module = swap_func(module)
            cur_module_name = cur_fqn.split(".")[-1]
            setattr(parent_module, cur_module_name, new_module)

    post_order_traversal(root_module)
    return root_module


def dequantize_float8_training(model: nn.Module) -> nn.Module:
    """
    Converts `Float8Linear` modules in `model` back to `torch.nn.Linear`.
    """

    def dequant_func(mod: Float8Linear) -> nn.Linear:
        new_module = nn.Linear(mod.in_features, mod.out_features)
        new_module.weight = mod.weight
        new_module.bias = mod.bias
        return new_module

    return swap_linear_layers(
        model,
        Float8Linear,
        dequant_func,
    )


dequantize_float8_training(m)
print("DEQUANTIZED MODEL".center(80, "-"))
print(m)

with torch.no_grad():
    y = m(x)
print(y)
```
@drisspg My understanding is that we just use the original model and the original weights trained by torchao if we prioritize quality.
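A minimal sketch of that idea, continuing from the toy example above; the fresh `m_infer` module and the `state_dict` round trip are illustrative assumptions, not something shown in the thread:

```python
import torch
import torch.nn as nn

# After dequantize_float8_training(m), the model holds plain nn.Linear layers with
# the bf16 weights learned under float8 training, so the original (unconverted)
# module definition can be reused for inference with those weights.
m_infer = nn.Sequential(
    nn.Linear(32, 64, bias=False),
    nn.Linear(64, 64, bias=False),
).bfloat16().cuda()
m_infer.load_state_dict(m.state_dict())  # reuse the trained high-precision weights
m_infer.eval()

with torch.no_grad():
    y = m_infer(x)
print(y.dtype)  # torch.bfloat16
```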
You should be able to use the existing quant_api flow whether you trained with or without FP8. Currently there is a bug, which we will fix.
quantize_() now supports Float8Linear #1344
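With that fix in place, the flow from the first comment should run end to end. A minimal sketch, assuming a torchao version that includes #1344 and reusing the float8-trained model `m` and input `x` from the first example:

```python
import torch
from torchao.quantization.quant_api import float8_weight_only, quantize_

# post-training, weight-only float8 quantization applied directly to a model
# that still contains Float8Linear modules from training
quantize_(m, float8_weight_only())

with torch.no_grad():
    y = m(x)
print(y.dtype)  # activations stay bfloat16; weights are stored as float8
```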