
What is the expected inference steps after I apply torchao in training? #1132

Closed
goldhuang opened this issue Oct 21, 2024 · 8 comments · Fixed by #1344

goldhuang commented Oct 21, 2024

Hello, I have integrated torchao into my training, but it's not 100% clear to me what inference should look like afterwards.

Should I use the converted FP8 linear layers to do inference? Is delayed scaling supposed to work in inference?
Or should I use the original linear layers to do inference?

Thanks a lot in advance if you can help clarify!

supriyar (Contributor) commented

cc @vkuzo @drisspg


drisspg commented Oct 22, 2024

@goldhuang I agree we need some more documentation here. Our previous inference solution handled this case correctly, but today it does not.

For example, the following will error:

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()


print(m)

from torchao.quantization.quant_api import float8_weight_only, quantize_

# quantize the trained model's weights to float8 for inference;
# this flow errored at the time of this issue (later fixed by #1344)
quantize_(m, float8_weight_only())

print(m)

with torch.no_grad():
    y = m(x)
    print(y.dtype)
    print(y)    

goldhuang (Author) commented

@drisspg Thanks!

Is quantize_(m, float8_weight_only()) mainly for inference speed?

If I only care about quality, which one is better? (We have other solutions for inference, like TRT. I only want to check the quality of the results after training with torchao.)

  1. use the converted Float8Linear modules + quantized weights
  2. use the original model


drisspg commented Oct 22, 2024

Ohh, that is an interesting question: you only care about quality and you don't care about weight size? If so, you are likely looking for a different solution; let me hack something up for you.


drisspg commented Oct 22, 2024

The example below shows how you can convert the layers from Float8Linear back to nn.Linear in the original high precision:

import torch
import torch.nn as nn
from typing import Callable, Optional
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(32, 64, bias=False),
    nn.Linear(64, 64, bias=False),
).bfloat16().cuda()

x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
target = torch.ones(16, 64, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.5)


# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# toy training loop
for _ in range(1000):
    optimizer.zero_grad()
    y = m(x)
    loss = torch.nn.functional.mse_loss(y, target)
    loss.backward()
    optimizer.step()


print(m)


### FINISHED TRAINING ###

from torchao.float8.float8_linear import Float8Linear

def swap_linear_layers(
    module: nn.Module,
    target_module: type[nn.Module],
    swap_func: Callable[[nn.Linear], nn.Linear],
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Generic function to swap linear layers in a module with a new type of linear layer.

    Note:
        If applied to a root-level nn.Linear, the module will not be modified in place;
        the new module is returned instead.

    Args:
        module: Module to modify.
        target_module: The module type to replace.
        swap_func: Function that accepts a linear layer and returns a new type of linear layer.
        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses
            that pass the filter function will be swapped. The inputs to the
            filter function are the module instance and the FQN.

    Returns:
     nn.Module: The modified module with swapped linear layers.
    """
    if isinstance(module, target_module) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root {target_module} with children: {module}"
            )
        return swap_func(module)

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"

            post_order_traversal(child_module, new_fqn, module)

        if isinstance(module, target_module) and (
            module_filter_fn is None or module_filter_fn(module, cur_fqn)
        ):
            assert (
                parent_module is not None
            ), f"{target_module} root module should return early: {module}"
            new_module = swap_func(module)
            cur_module_name = cur_fqn.split(".")[-1]
            setattr(parent_module, cur_module_name, new_module)

    post_order_traversal(root_module)
    return root_module




def dequantize_float8_training(model: nn.Module) -> nn.Module:
    """
    Converts `Float8Linear` modules in `model` to `torch.nn.Linear`.
    """

    def dequant_func(mod: Float8Linear) -> nn.Linear:
        # reuse the high-precision weight (and bias, if any) kept by the trained Float8Linear
        new_module = nn.Linear(mod.in_features, mod.out_features)
        new_module.weight = mod.weight
        new_module.bias = mod.bias
        return new_module

    return swap_linear_layers(
        model,
        Float8Linear,
        dequant_func,
    )

dequantize_float8_training(m)
print("DEQUANTIZED MODEL".center(80, "-"))

print(m)

with torch.no_grad():
    y = m(x)
    print(y)
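
If the goal is to hand the trained weights to an external inference stack (goldhuang mentions TRT), here is a minimal sketch of one way to continue from the example above; the checkpoint path and m_infer are illustrative additions, not part of the original snippet:

# hypothetical continuation: after dequantize_float8_training(m), the state_dict holds
# plain high-precision tensors, so it can be loaded into an unconverted copy of the model
torch.save(m.state_dict(), "trained_bf16.pt")  # illustrative path

m_infer = nn.Sequential(
    nn.Linear(32, 64, bias=False),
    nn.Linear(64, 64, bias=False),
).bfloat16().cuda()
m_infer.load_state_dict(torch.load("trained_bf16.pt"))

with torch.no_grad():
    print(m_infer(x))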


goldhuang commented Oct 22, 2024

@drisspg My understanding is that we just use the original model with the original weights trained with torchao if we prioritize quality.
If we want to use fp8 at inference, we keep the Float8Linear modules but also convert the weights to fp8 at inference time.


drisspg commented Oct 22, 2024

You should be able to use the existing quant_api flow whether you trained with or without FP8. Currently there is a bug, which we will fix.
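
For reference, a minimal sketch of that quant_api flow on a model that was not trained with float8; the shapes here are illustrative, and float8_weight_only is just one possible config:

import torch
import torch.nn as nn
from torchao.quantization.quant_api import float8_weight_only, quantize_

m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()

# quantize the weights to float8 for inference; no prior float8 training required
quantize_(m, float8_weight_only())

x = torch.randn(16, 2048, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    print(m(x).dtype)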

jainapurva (Contributor) commented

quantize_() now supports Float8Linear #1344
