Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support int_scaled_mm on CPU #121

Merged
merged 4 commits into from
Apr 5, 2024
Merged

Conversation

Xia-Weiwen
Copy link
Collaborator

@Xia-Weiwen Xia-Weiwen commented Apr 4, 2024

Description
int_scaled_mm is supported on CUDA only now. This PRs adds support for CPU.
The op is implemented by torch._int_mm, whose CPU version has been added to PyTorch recently by pytorch/pytorch#121792.
With this patch, SmoothQuant can go with int_scaled_mm on CPU with Inductor.
Example code:

import torch
from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear, smooth_fq_linear_to_inference
# convert linear modules to smoothquant
# linear module in calibration mode
swap_linear_with_smooth_fq_linear(model)
model.train()
# Calibrate the model
for data in dataset:
    model(*data)
# set it to inference mode
smooth_fq_linear_to_inference(model.eval())
with torch.no_grad():
    optimized_model = torch.compile(model)
    _ = optimized_model(*example_inputs)
    _ = optimized_model(*example_inputs)

Run with TORCHAO_AUTOTUNER_ENABLE=1 and the following is found in the generated code:

auto buf3 = op_torchao_int_scaled_matmul_.call(reinterpret_tensor(buf1, {16L, 1024L}, {1024L, 1L}, 0L), _frozen_param3, reinterpret_tensor(buf2, {16L, 1024L}, {1L, 0L}, 0L));

Test plan
python test/kernel/test_autotuner.py -k test_int_scaled_mm

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 4, 2024
@Xia-Weiwen
Copy link
Collaborator Author

CC @jgong5 @leslie-fang-intel

@cpuhrsch cpuhrsch merged commit fc5d2c8 into pytorch:main Apr 5, 2024
7 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
@maktukmak
Copy link

maktukmak commented Sep 19, 2024

@Xia-Weiwen , @cpuhrsch , can you also add _scaled_mm for fp8 matmul? It would be very useful for FP quantization methods and mixed precision training. Currently, torch.matmul runs with FP8 inputs on the CPU but the result overflows. Probably, the accumulation dtype is FP8. _scaled_mm can solve this problem.

@Xia-Weiwen Xia-Weiwen deleted the cpu_int_scaled_mm branch September 20, 2024 00:44
@Xia-Weiwen
Copy link
Collaborator Author

@Xia-Weiwen , @cpuhrsch , can you also add _scaled_mm for fp8 matmul? It would be very useful for FP quantization methods and mixed precision training. Currently, torch.matmul runs with FP8 inputs on the CPU but the result overflows. Probably, the accumulation dtype is FP8. _scaled_mm can solve this problem.

Hi @jgong5 @yanbing-j Could you please comment about FP8? Thanks.

@yanbing-j
Copy link
Contributor

@Xia-Weiwen , @cpuhrsch , can you also add _scaled_mm for fp8 matmul? It would be very useful for FP quantization methods and mixed precision training. Currently, torch.matmul runs with FP8 inputs on the CPU but the result overflows. Probably, the accumulation dtype is FP8. _scaled_mm can solve this problem.

Hi @jgong5 @yanbing-j Could you please comment about FP8? Thanks.

At present, we are preparing to add CPU support of _scaled_mm using fp8 matmul. Hope this feature can be addressed in the near future.

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
Update README.md

Update README.md (pytorch#118)

Update README.md

Update README.md (pytorch#121)

Update REAME based on pytorch#107
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants