Lower quant/dequant torch op to StableHLO #5763

Merged: 23 commits, merged on Nov 28, 2023. Changes shown from 18 commits.
2 changes: 2 additions & 0 deletions WORKSPACE
@@ -43,6 +43,8 @@ http_archive(
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_topk_rewriter.diff",
"//openxla_patches:quant_dequant_converter.diff",
"//openxla_patches:stablehlo_quant_seralization.diff",
],
strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478",
urls = [
137 changes: 137 additions & 0 deletions openxla_patches/quant_dequant_converter.diff
@@ -0,0 +1,137 @@
// TODO(lsy323): This patch modifies the HLO->StableHLO converter so that custom calls targeting
// stablehlo.uniform_quantize/dequantize are lowered to the corresponding stablehlo.uniform_quantize/dequantize ops.
// The patch can be removed once quantize/dequantize ops and quantized dtypes are supported in HLO.
diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD
index f74973ae1..8e3f0e06b 100644
--- a/xla/translate/hlo_to_mhlo/BUILD
+++ b/xla/translate/hlo_to_mhlo/BUILD
@@ -67,6 +67,7 @@ cc_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SparseTensorDialect",
"@tsl//tsl/platform:statusor",
diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
index 08d5f49c8..2f9ad1e0b 100644
--- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -664,6 +664,70 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
return importer.ImportInstructionWithLayout(instr, operands, builder, mode);
}

+Type getQuantizedType(mlir::DictionaryAttr& backend_config) {
+ std::vector<double> scales;
+ std::vector<int64_t> zero_points;
+ int64_t quantization_dimension = -1, storage_max = 0, storage_min = 0;
+ Type storage_type, expressed_type;
+
+ auto scales_attr = backend_config.get("scale");
+ if (scales_attr) {
+ for (auto scale_attr : scales_attr.cast<mlir::ArrayAttr>()) {
+ scales.push_back(scale_attr.cast<mlir::FloatAttr>().getValueAsDouble());
+ }
+ }
+
+ auto zero_points_attr = backend_config.get("zero_point");
+ if (zero_points_attr) {
+ for (auto zero_point_attr : zero_points_attr.cast<mlir::ArrayAttr>()) {
+ zero_points.push_back(zero_point_attr.cast<mlir::IntegerAttr>().getInt());
+ }
+ }
+
+ auto quantization_dimension_attr =
+ backend_config.get("quantization_dimension");
+ if (quantization_dimension_attr) {
+ quantization_dimension =
+ quantization_dimension_attr.cast<mlir::IntegerAttr>().getInt();
+ }
+
+ auto storage_max_attr = backend_config.get("storage_max");
+ if (storage_max_attr) {
+ storage_max = storage_max_attr.cast<mlir::IntegerAttr>().getInt();
+ }
+
+ auto storage_min_attr = backend_config.get("storage_min");
+ if (storage_min_attr) {
+ storage_min = storage_min_attr.cast<mlir::IntegerAttr>().getInt();
+ }
+
+ auto storage_type_attr = backend_config.get("storage_type");
+ if (storage_type_attr) {
+ storage_type = storage_type_attr.cast<mlir::TypeAttr>().getValue();
+ //.cast<mlir::ShapedType>()
+ //.getElementType();
+ }
+
+ auto expressed_type_attr = backend_config.get("expressed_type");
+ if (expressed_type_attr) {
+ expressed_type = expressed_type_attr.cast<mlir::TypeAttr>().getValue();
+ //.cast<mlir::ShapedType>()
+ //.getElementType();
+ }
+
+ auto is_signed = storage_type.cast<mlir::IntegerType>().isSigned();
+
+ if (quantization_dimension != -1) {
+ return mlir::quant::UniformQuantizedPerAxisType::get(
+ is_signed, storage_type, expressed_type, scales, zero_points,
+ quantization_dimension, storage_min, storage_max);
+ } else {
+ return mlir::quant::UniformQuantizedType::get(
+ is_signed, storage_type, expressed_type, scales[0], zero_points[0],
+ storage_min, storage_max);
+ }
+}
+
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
const HloInstruction* instruction,
const llvm::SmallVectorImpl<mlir::Value>& operands,
@@ -933,6 +997,25 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
"Couldn't parse backend config into a dictionary attribute");

attributes.push_back(builder_->getNamedAttr("backend_config", attr));
+ auto backend_config = attr.cast<mlir::DictionaryAttr>();
+ if (custom_call->custom_call_target() ==
+ "stablehlo.uniform_quantize") {
+ return func_builder
+ ->create<mlir::mhlo::UniformQuantizeOp>(
+ loc,
+ mlir::RankedTensorType::get(
+ result_type.cast<RankedTensorType>().getShape(),
+ getQuantizedType(backend_config)),
+ operands)
+ .getOperation();
+ }
+
+ if (custom_call->custom_call_target() ==
+ "stablehlo.uniform_dequantize") {
+ return func_builder
+ ->create<mlir::mhlo::UniformDequantizeOp>(
+ loc, result_type, operands) .getOperation();
+ }
}
} else {
attributes.push_back(builder_->getNamedAttr(
diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
index 9f05992c8..03cf4840d 100644
--- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <vector>

+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -41,6 +43,7 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module,
module.getContext()->loadDialect<mlir::arith::ArithDialect>();
module.getContext()->loadDialect<mlir::func::FuncDialect>();
module.getContext()->loadDialect<mlir::mhlo::MhloDialect>();
+ module.getContext()->loadDialect<mlir::quant::QuantizationDialect>();
}

namespace {
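For reference, here is a minimal Python sketch of the mapping that getQuantizedType() in the patch above performs. The backend_config keys (scale, zero_point, quantization_dimension, storage_type, expressed_type) are taken from the patch; the helper itself and its string output are hypothetical and only illustrate the resulting MLIR quantized element types (storage_min/storage_max also feed the type but are omitted here for brevity).

```python
# Illustrative sketch only: mirrors the logic of getQuantizedType() in the patch above.
# The backend_config keys come from the patch; the helper and its output strings are
# hypothetical, showing the intended MLIR quantized element types.
def quantized_element_type(backend_config: dict) -> str:
    scales = backend_config.get("scale", [])
    zero_points = backend_config.get("zero_point", [])
    axis = backend_config.get("quantization_dimension", -1)
    storage = backend_config["storage_type"]      # e.g. "i8"
    expressed = backend_config["expressed_type"]  # e.g. "f32"
    if axis != -1:
        # Per-axis: one (scale, zero_point) pair per slice along `axis`.
        pairs = ",".join(f"{s}:{z}" for s, z in zip(scales, zero_points))
        return f"!quant.uniform<{storage}:{expressed}:{axis}, {{{pairs}}}>"
    # Per-tensor: a single (scale, zero_point) pair.
    return f"!quant.uniform<{storage}:{expressed}, {scales[0]}:{zero_points[0]}>"

# Per-tensor int8 with scale 0.4 and zero point 2 (matches test_per_tensor_qdq below):
print(quantized_element_type({
    "scale": [0.4], "zero_point": [2], "quantization_dimension": -1,
    "storage_type": "i8", "expressed_type": "f32",
}))  # -> !quant.uniform<i8:f32, 0.4:2>
```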
45 changes: 45 additions & 0 deletions openxla_patches/stablehlo_quant_seralization.diff
@@ -0,0 +1,45 @@
// TODO(lsy323): This patch is needed to serialize stablehlo.uniform_quantize/dequantize in bytecode format
// This patch can be removed after https://github.com/openxla/stablehlo/issues/1812 is fixed.
diff --git a/third_party/stablehlo/stablehlo_quant_seralization.patch b/third_party/stablehlo/stablehlo_quant_seralization.patch
new file mode 100644
index 000000000..24e23b67d
--- /dev/null
+++ b/third_party/stablehlo/stablehlo_quant_seralization.patch
@@ -0,0 +1,26 @@
+diff --git a/stablehlo/api/PortableApi.cpp b/stablehlo/api/PortableApi.cpp
+index 07c856db..cd169cae 100644
+--- a/stablehlo/api/PortableApi.cpp
++++ b/stablehlo/api/PortableApi.cpp
+@@ -15,10 +15,13 @@ limitations under the License.
+
+ #include "stablehlo/api/PortableApi.h"
+
++#include <iostream>
+ #include <string>
+
+ #include "mlir/Bytecode/BytecodeWriter.h"
+ #include "mlir/Dialect/Func/IR/FuncOps.h"
++#include "mlir/Dialect/Quant/QuantOps.h"
++#include "mlir/Dialect/Quant/QuantTypes.h"
+ #include "mlir/IR/MLIRContext.h"
+ #include "mlir/Parser/Parser.h"
+ #include "stablehlo/dialect/Serialization.h"
+@@ -33,6 +36,7 @@ void loadSerializationDialects(MLIRContext* context) {
+ context->loadDialect<mlir::func::FuncDialect>();
+ context->loadDialect<mlir::stablehlo::StablehloDialect>();
+ context->loadDialect<mlir::vhlo::VhloDialect>();
++ context->loadDialect<mlir::quant::QuantizationDialect>();
+ }
+ } // namespace
+
diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
index 9f4494aac..64fa072bb 100644
--- a/third_party/stablehlo/workspace.bzl
+++ b/third_party/stablehlo/workspace.bzl
@@ -15,5 +15,6 @@ def repo():
urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)),
patch_file = [
"//third_party/stablehlo:temporary.patch", # Autogenerated, don't remove.
+ "//third_party/stablehlo:stablehlo_quant_seralization.patch", # Load quant dialect.
],
)
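Why the dialect registration above matters in practice: the bytecode path is what torch_xla uses when saving StableHLO programs, and until the upstream issue is fixed the test below also sets STABLEHLO_BYTECODE_FROM_PRETTYPRINT=1. A rough sketch under those assumptions follows; get_stablehlo_bytecode is assumed to be the torch_xla.stablehlo accessor alongside get_stablehlo_text, which the test uses.

```python
# Rough sketch, assuming the torch_xla.stablehlo API used in the test below;
# get_stablehlo_bytecode is assumed to exist alongside get_stablehlo_text.
import os
import torch
from torch_xla import stablehlo

# Workaround from the test below for https://github.com/openxla/stablehlo/issues/1812.
os.environ['STABLEHLO_BYTECODE_FROM_PRETTYPRINT'] = '1'

def export_stablehlo_bytecode(model: torch.nn.Module, args):
    exported = torch.export.export(model.eval(), args)
    gm = stablehlo.exported_program_to_stablehlo(exported)
    # Serializing quantized programs to bytecode is the step that needs the
    # quant dialect registered (the patch above).
    return gm.get_stablehlo_bytecode('forward')
```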
72 changes: 72 additions & 0 deletions test/stablehlo/test_pt2e_qdq.py
@@ -0,0 +1,72 @@
import os
import sys
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
from torch._export import capture_pre_autograd_graph
import torchvision
from torch_xla import stablehlo
import torch_xla.core.xla_model as xm
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model
import unittest

# Needed to work around the StableHLO bytecode serialization issue in https://github.com/openxla/stablehlo/issues/1812
os.environ['STABLEHLO_BYTECODE_FROM_PRETTYPRINT'] = '1'


class PT2EExportTest(unittest.TestCase):

def test_per_tensor_qdq(self):
device = xm.xla_device()
x = torch.randn(2, 3, 4, 5).to(device)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, 0.4, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, 0.4, 2, -128, 127, torch.int8)
stablehlo_txt = xm.get_stablehlo([x])
self.assertTrue("stablehlo.uniform_quantize" in stablehlo_txt)
self.assertTrue("stablehlo.uniform_dequantize" in stablehlo_txt)

def test_per_channel_qdq(self):
device = xm.xla_device()
x = torch.randn(2, 3, 4, 5).to(device)
scale = torch.tensor([3.2, 5.3, 0.1, 10])
zero_point = torch.tensor([1, 2, -1, -2])
x = torch.ops.quantized_decomposed.quantize_per_channel(
x, scale, zero_point, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.dequantize_per_channel(
x, scale, zero_point, 2, -128, 127, torch.int8)
stablehlo_txt = xm.get_stablehlo([x])
self.assertTrue("stablehlo.uniform_quantize" in stablehlo_txt)
self.assertTrue("stablehlo.uniform_dequantize" in stablehlo_txt)

def test_resnet18(self):
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
m = capture_pre_autograd_graph(m, args)
Review comment (Collaborator): Is there a reason we use this instead of torch.export?

Review comment (Collaborator): OK, I saw the export below, but I'm still confused about what this function does to the module.

Reply (lsy323, author, Nov 28, 2023): Here the graph is captured for PT2E to process further. PT2E doesn't work with a graph captured from torch.export (just tried locally); it needs the graph captured this way. The torch.export call further down is for the PyTorch -> StableHLO export, since our API only works on an exported program. (A condensed sketch of this two-stage flow follows the test file below.)


# Step 2: Insert observers or fake quantize modules
quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)

# Step 3: Quantize the model
m = convert_pt2e(m)

# Trace with torch/xla and export stablehlo
exported = torch.export.export(m, args)
stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
stablehlo_txt = stablehlo_gm.get_stablehlo_text('forward')
# print(stablehlo_txt)
self.assertTrue("stablehlo.uniform_quantize" in stablehlo_txt)
self.assertTrue("stablehlo.uniform_dequantize" in stablehlo_txt)
# Save as tf.saved_model
# save_torch_module_as_tf_saved_model(m, args, '/tmp/tf_saved_model/tmp1')


if __name__ == '__main__':
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
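A condensed sketch of the two-stage flow discussed in the review comments above, assembled from the same calls test_resnet18 uses; the helper name and its packaging as a function are illustrative.

```python
# Condensed sketch of the PT2E quantization + StableHLO export flow used in
# test_resnet18 above. Only the function wrapper is new; the calls mirror the test.
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer, get_symmetric_quantization_config)
from torch_xla import stablehlo

def quantize_and_export(model: torch.nn.Module, args):
    # Stage 1: PT2E capture -- prepare/convert operate on this graph, not on an
    # ExportedProgram (per the review discussion above).
    m = capture_pre_autograd_graph(model.eval(), args)
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)   # insert observers / fake-quant modules
    m = convert_pt2e(m)              # lower to quantize/dequantize ops
    # Stage 2: torch.export capture -- the torch_xla StableHLO exporter consumes
    # an ExportedProgram.
    exported = torch.export.export(m, args)
    return stablehlo.exported_program_to_stablehlo(exported)
```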
2 changes: 2 additions & 0 deletions torch_xla/csrc/BUILD
@@ -49,6 +49,7 @@ ptxla_cc_library(
"nll_loss.cpp",
"nms_op.cpp",
"pooling.cpp",
"quant_util.cpp",
"random.cpp",
"reduction.cpp",
"resize_ops.cpp",
@@ -88,6 +89,7 @@ ptxla_cc_library(
"nll_loss.h",
"nms_op.h",
"pooling.h",
"quant_util.h",
"random.h",
"reduction.h",
"resize_ops.h",
46 changes: 46 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -199,6 +199,28 @@ at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input,
return bridge::AtenFromXlaTensor(std::move(result));
}

at::Tensor QuantizeTensor(const at::Tensor& input,
const std::vector<float>& scale_list,
const std::vector<float>& zero_point_list,
int quant_min, int quant_max,
const std::string& dtype, int axis) {
auto result = tensor_methods::quantize_tensor(
bridge::GetXlaTensor(input), scale_list, zero_point_list, quant_min,
quant_max, dtype, axis);
return bridge::AtenFromXlaTensor(std::move(result));
}

at::Tensor DequantizeTensor(const at::Tensor& input,
const std::vector<float>& scale_list,
const std::vector<float>& zero_point_list,
int quant_min, int quant_max,
const std::string& dtype, int axis) {
auto result = tensor_methods::dequantize_tensor(
bridge::GetXlaTensor(input), scale_list, zero_point_list, quant_min,
quant_max, dtype, axis);
return bridge::AtenFromXlaTensor(std::move(result));
}

std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
const std::string& reduce_type, const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
@@ -1104,6 +1126,30 @@ void InitXlaModuleBindings(py::module m) {
}
return result;
});
m.def("_xla_quantize_tensor",
[](const at::Tensor& input, const std::vector<float>& scale_list,
const std::vector<float>& zero_point_list, int quant_min,
int quant_max, const std::string& dtype, int axis) -> at::Tensor {
at::Tensor result;
{
NoGilSection nogil;
result = QuantizeTensor(input, scale_list, zero_point_list,
quant_min, quant_max, dtype, axis);
}
return result;
});
m.def("_xla_dequantize_tensor",
[](const at::Tensor& input, const std::vector<float>& scale_list,
const std::vector<float>& zero_point_list, int quant_min,
int quant_max, const std::string& dtype, int axis) -> at::Tensor {
at::Tensor result;
{
NoGilSection nogil;
result = DequantizeTensor(input, scale_list, zero_point_list,
quant_min, quant_max, dtype, axis);
}
return result;
});
m.def("_xla_all_to_all",
[](const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token,
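A hedged usage sketch for the two bindings registered above, called through torch_xla's private _XLAC module. The argument order follows the m.def signatures in this diff; the dtype string format and the example values are assumptions.

```python
# Hedged sketch: invoking the bindings added above via torch_xla._XLAC.
# Argument order follows the m.def signatures in this diff; the dtype string
# form ("torch.int8") and the values are illustrative assumptions.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.randn(2, 3, 4, 5).to(xm.xla_device())
q = torch_xla._XLAC._xla_quantize_tensor(
    x,                # input
    [0.4],            # scale_list (single entry => per-tensor)
    [2.0],            # zero_point_list
    -128, 127,        # quant_min, quant_max
    "torch.int8",     # dtype (string form assumed)
    -1)               # axis (-1 assumed to mean per-tensor)
dq = torch_xla._XLAC._xla_dequantize_tensor(
    q, [0.4], [2.0], -128, 127, "torch.int8", -1)
```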
8 changes: 4 additions & 4 deletions torch_xla/csrc/ir_dump_util.cpp
@@ -288,11 +288,11 @@ std::string DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value> values,
case EmitMode::kHloReadable:
return ConsumeValue(runtime::util::GetComputationHloText(computation));
case EmitMode::kStableHloReadable:
return runtime::hloToStablehlo(&computation.proto(),
/* emit_bytecode = */ false);
return hloToStablehlo(&computation.proto(),
/* emit_bytecode = */ false);
case EmitMode::kStableHloBytecode:
return runtime::hloToStablehlo(&computation.proto(),
/* emit_bytecode = */ true);
return hloToStablehlo(&computation.proto(),
/* emit_bytecode = */ true);
}
}
