forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Quantize vit_b_16 tutorial - Part 1 (pytorch#60)
- Loading branch information
Showing 9 changed files with 4,214 additions and 8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
"""Top-level torchao package: re-exports dtypes and the quantization entry points."""
from . import dtypes
from .quantization.quant_api import apply_dynamic_quant
from .quantization.quant_api import apply_weight_only_int8_quant

# Public API of the package.
# NOTE(review): apply_weight_only_int8_quant was imported above but missing
# from __all__; added so the export list matches the imports.
__all__ = [
    "dtypes",
    "apply_dynamic_quant",
    "apply_weight_only_int8_quant",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/bin/bash
# Benchmark the bfloat16 baseline and the dynamically-quantized ViT-B/16
# scripts, then report where torch.compile wrote the generated kernels.

# Print the path of the inductor output-code file produced by running $1.
generated_code_path() {
    TORCH_LOGS='output_code' python "$1" 2>&1 | grep "Output code written to: " | awk -F" " '{print $NF}'
}

# Run bfloat16 version
TORCH_LOGS='graph_breaks,recompiles' python run_vit_b.py

# Run dynamic quantized version
TORCH_LOGS='graph_breaks,recompiles' python run_vit_b_quant.py

# Store the output code for further inspection
echo "bfloat16 generated code lives in:"
generated_code_path run_vit_b.py
echo "quantization generated code lives in:"
generated_code_path run_vit_b_quant.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import torch
import torchvision.models.vision_transformer as models

# Load Vision Transformer model.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13; the
# weights-enum spelling below loads the same IMAGENET1K_V1 checkpoint.
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

# Set the model to evaluation mode, move it to the GPU in bfloat16.
model.eval().cuda().to(torch.bfloat16)

# Input tensor (batch_size, channels, height, width)
input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')

# Compile with max-autotune so inductor searches for the fastest kernels.
model = torch.compile(model, mode='max-autotune')
|
||
def benchmark_model(model, num_runs, input_tensor):
    """Return the average GPU wall-time per forward pass, in milliseconds.

    Times `num_runs` calls of `model(input_tensor)` between two CUDA events,
    synchronizing before and after so queued kernels are fully accounted for.
    """
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    start.record()

    # Tag every iteration so it shows up as "timed region" in profiler traces.
    completed = 0
    while completed < num_runs:
        with torch.autograd.profiler.record_function("timed region"):
            model(input_tensor)
        completed += 1

    stop.record()
    torch.cuda.synchronize()
    return start.elapsed_time(stop) / num_runs
|
||
def profiler_runner(path, fn, *args, **kwargs):
    """Run `fn(*args, **kwargs)` under the torch profiler.

    Records CPU and CUDA activity with input shapes, writes a chrome trace
    to `path`, and returns whatever `fn` returned.
    """
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    prof.export_chrome_trace(path)
    return result
|
||
# Must run with no_grad when optimizing for inference
with torch.no_grad():
    # warmup
    benchmark_model(model, 5, input_tensor)
    # benchmark
    elapsed_ms = benchmark_model(model, 100, input_tensor)
    print("elapsed_time: ", elapsed_ms, " milliseconds")
    # Create a trace
    profiler_runner("bfloat16.json.gz", benchmark_model, model, 5, input_tensor)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import torch
import torchao
import torchvision.models.vision_transformer as models

# Load Vision Transformer model.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13; the
# weights-enum spelling below loads the same IMAGENET1K_V1 checkpoint.
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

# Set the model to evaluation mode, move it to the GPU in bfloat16.
model.eval().cuda().to(torch.bfloat16)

# Input tensor (batch_size, channels, height, width)
input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')

## Quantization code - start
# Swap linear weights for dynamically-quantized int8 versions, and ask
# inductor to fuse the int8 matmul with the following multiply.
torchao.apply_dynamic_quant(model)
from torch._inductor import config as inductorconfig
inductorconfig.force_fuse_int_mm_with_mul = True
## Quantization code - end

# Compile with max-autotune so inductor searches for the fastest kernels.
model = torch.compile(model, mode='max-autotune')
|
||
def benchmark_model(model, num_runs, input_tensor):
    """Return the average GPU wall-time per forward pass, in milliseconds.

    Times `num_runs` calls of `model(input_tensor)` between two CUDA events,
    synchronizing before and after so queued kernels are fully accounted for.
    """
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    start.record()

    # Tag every iteration so it shows up as "timed region" in profiler traces.
    completed = 0
    while completed < num_runs:
        with torch.autograd.profiler.record_function("timed region"):
            model(input_tensor)
        completed += 1

    stop.record()
    torch.cuda.synchronize()
    return start.elapsed_time(stop) / num_runs
|
||
def profiler_runner(path, fn, *args, **kwargs):
    """Run `fn(*args, **kwargs)` under the torch profiler.

    Records CPU and CUDA activity with input shapes, writes a chrome trace
    to `path`, and returns whatever `fn` returned.
    """
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    prof.export_chrome_trace(path)
    return result
|
||
# Must run with no_grad when optimizing for inference
with torch.no_grad():
    # warmup
    benchmark_model(model, 5, input_tensor)
    # benchmark
    elapsed_ms = benchmark_model(model, 100, input_tensor)
    print("elapsed_time: ", elapsed_ms, " milliseconds")
    # Create a trace
    profiler_runner("quant.json.gz", benchmark_model, model, 5, input_tensor)