enable int8 LSTM on latest cpu-device (#692)
* int8 lstm graph rewrite

* integrate oneDNN int8 lstm

* use the input's scale and zero point as the output's for lstm

* add UT for int8 lstm

* rename var to maybe_quantized_lstm

* add assertion for input scalar type

* only get input scalar type once

* add doxygen spec for pack_qlstm_weight

* add doxygen spec for quantized_lstm

* use inline utils function to get scale and zero point of input and weight
chunyuan-w authored Apr 19, 2022
1 parent c28e621 commit 2bf8dba
Showing 12 changed files with 597 additions and 67 deletions.
326 changes: 280 additions & 46 deletions intel_extension_for_pytorch/csrc/aten/cpu/RNN.cpp

Large diffs are not rendered by default.

30 changes: 15 additions & 15 deletions intel_extension_for_pytorch/csrc/aten/cpu/RNN.h
@@ -40,7 +40,10 @@ class IPEXLSTMOp : public torch::autograd::Function<IPEXLSTMOp> {
bool has_biases,
bool bidirectional,
bool batch_first,
bool train);
bool train,
double scale,
int64_t zp,
int64_t dtype);
static std::vector<at::Tensor> forward(
torch::autograd::AutogradContext* ctx,
const at::Tensor& input,
@@ -58,7 +61,10 @@ class IPEXLSTMOp : public torch::autograd::Function<IPEXLSTMOp> {
bool has_biases,
bool bidirectional,
bool batch_first,
bool train);
bool train,
double scale,
int64_t zp,
int64_t dtype);

static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
@@ -81,7 +87,10 @@ std::vector<at::Tensor> ipex_lstm_layer(
bool has_biases,
bool bidirectional,
bool batch_first,
bool train);
bool train,
double scale,
int64_t zp,
int64_t dtype);

std::vector<at::Tensor> ipex_lstm_layer_backward(
const at::Tensor& input,
@@ -124,18 +133,9 @@ std::vector<at::Tensor> ipex_lstm_layer_forward(
bool has_biases,
bool bidirectional,
bool batch_first,
bool train);

static std::tuple<at::Tensor, at::Tensor, at::Tensor> ipex_lstm(
const at::Tensor& input,
std::vector<at::Tensor> hx,
std::vector<at::Tensor> params,
bool has_biases,
int64_t num_layers,
double dropout_p,
bool train,
bool bidirectional,
bool batch_first);

bool train,
double scale,
int64_t zp,
int64_t dtype);
} // namespace cpu
} // namespace torch_ipex
@@ -94,8 +94,10 @@ ideep::tensor itensor_view_from_dense(
"itensor_view_from_dense expects dense tensor input");
TORCH_CHECK(
tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16,
"itensor_view_from_dense expects float or bfloat16 tensor input");
tensor.scalar_type() == at::ScalarType::BFloat16 ||
tensor.scalar_type() == at::ScalarType::QInt8 ||
tensor.scalar_type() == at::ScalarType::QUInt8,
"itensor_view_from_dense expects float, bfloat16 or int8 tensor input");
return {desc, tensor.data_ptr()};
}

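With this change, an INT8 quantized PyTorch tensor can be viewed as an ideep tensor without a copy. A minimal sketch (an assumption-laden illustration: it presumes an overload taking an explicit desc, and the shape, scale, and zero point are made up):

#include <ATen/ATen.h>
#include "csrc/cpu/ideep/ideep.hpp"

// Hypothetical: view a u8 quantized tensor as an ideep tensor, zero-copy.
void view_quantized_example() {
  at::Tensor qx = at::quantize_per_tensor(
      at::rand({2, 3}), /*scale=*/0.05, /*zero_point=*/128, at::kQUInt8);
  ideep::tensor::desc desc(
      {2, 3}, ideep::tensor::data_type::u8, ideep::format_tag::ab);
  ideep::tensor view = itensor_view_from_dense(qx, desc);  // shares qx storage
}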
@@ -18,6 +18,10 @@ struct lstm_forward_inference : public dnnl::lstm_forward {
tensor& dst_iter_c,
const bool reverse = false,
const prop_kind aprop = prop_kind::forward_inference,
const float scale = -1.,
const int32_t zp = -1,
const int weights_scale_mask = -1,
const std::vector<float>& weights_scales = scale_t(),
const engine& aengine = engine::cpu_engine()) {
auto direction = reverse ? rnn_direction::unidirectional_right2left
: rnn_direction::unidirectional_left2right;
@@ -30,13 +34,21 @@
auto weights_layer_desc = weights_layer.get_desc().to_format_any();
auto weights_iter_desc = weights_iter.get_desc().to_format_any();

attr_t op_attr;
if (src_layer.get_data_type() == data_type::u8) {
weights_layer_desc = weights_layer_desc.to_type(data_type::s8);
weights_iter_desc = weights_iter_desc.to_type(data_type::s8);

op_attr.set_rnn_data_qparams(scale, zp);
op_attr.set_rnn_weights_qparams(weights_scale_mask, weights_scales);
}

auto bias_desc = bias.get_desc();
auto dst_layer_desc = dst_layer.get_desc();
auto dst_iter_desc = dst_iter.get_desc();
auto dst_iter_c_desc = dst_iter_c.get_desc();

// Use user mode scratchpad
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

auto pd = primitive_desc(
@@ -55,9 +67,9 @@ struct lstm_forward_inference : public dnnl::lstm_forward {
aengine);

auto expected_weights_layer =
weights_layer.reorder_if_differ_in(pd.weights_layer_desc());
weights_layer.reorder_if_differ_in(pd.weights_layer_desc(), op_attr);
auto expected_weights_iter =
weights_iter.reorder_if_differ_in(pd.weights_iter_desc());
weights_iter.reorder_if_differ_in(pd.weights_iter_desc(), op_attr);
tensor scratchpad(pd.scratchpad_desc());

super(pd).execute(
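For context, a minimal sketch of how these oneDNN RNN quantization attributes are assembled (assuming the plain dnnl C++ API; the scale, shift, mask, and weight-scale values below are illustrative, not taken from this commit):

#include <vector>
#include "dnnl.hpp"

// u8 activations take one scale/shift pair; s8 weights take a vector of
// scales whose layout is chosen by the mask (0 = a single common scale).
dnnl::primitive_attr make_int8_rnn_attr() {
  dnnl::primitive_attr attr;
  std::vector<float> weights_scales(1, 64.f);                 // illustrative
  attr.set_rnn_data_qparams(/*scale=*/20.f, /*shift=*/64.f);  // src/dst u8
  attr.set_rnn_weights_qparams(/*mask=*/0, weights_scales);   // s8 weights
  attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);      // user scratchpad
  return attr;
}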
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <string>
#include "csrc/jit/cpu/passes/graph_rewrite.h"
@@ -70,8 +71,11 @@ void IpexQuantFusion(std::shared_ptr<Graph>& graph) {
rewriter.RegisterRewritePattern(info.pattern, info.replacement);
rewriter.runOnGraph(graph, info.filters);
}
GRAPH_DUMP("Before IpexQuantFusion", graph);
graph_rewrite::replaceEmbeddingBagWithQEmbeddingBag(graph);
graph_rewrite::replaceInteractionWithQInteraction(graph);
graph_rewrite::replaceLstmWithQLstm(graph);
GRAPH_DUMP("After IpexQuantFusion", graph);
}

} // namespace jit
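Note: GRAPH_DUMP is gated by PyTorch's JIT logging, so these before/after dumps only appear when logging is enabled for this pass's source file via the PYTORCH_JIT_LOG_LEVEL environment variable (set it to the file's basename).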
81 changes: 81 additions & 0 deletions intel_extension_for_pytorch/csrc/jit/cpu/kernels/RNN.h
@@ -0,0 +1,81 @@
#pragma once

#include <ATen/Tensor.h>

#include <c10/core/Scalar.h>
#include <torch/csrc/jit/runtime/custom_operator.h>

#include "csrc/cpu/ideep/ideep.hpp"

namespace torch_ipex {
namespace cpu {

//! function: quantized_lstm
/*!
 *
 * Compute a quantized LSTM for INT8 input, INT8 weights and FP32 initial
 * hidden and cell states, returning an INT8 output along with FP32 final
 * hidden and cell states.
 * \param input: INT8 tensor of shape :math:`(L, N, H_{in})` when
 *   ``batch_first=False`` or :math:`(N, L, H_{in})` when ``batch_first=True``,
 *   containing the features of the input sequence.
 * \param hx: list of the FP32 initial hidden state and cell state:
 *   hx[0]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})`
 *   containing the initial hidden state for the input sequence batch.
 *   hx[1]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})`
 *   containing the initial cell state for the input sequence batch.
 * \param weights: list of INT8 weights and FP32 biases.
 * \param has_biases: if ``False``, the layer does not use the bias weights
 *   `b_ih` and `b_hh`.
 * \param num_layers: the number of LSTM layers.
 * \param dropout_p: if non-zero, introduces a `Dropout` layer on the outputs
 *   of each RNN layer except the last layer, with dropout probability equal
 *   to :attr:`dropout` when the model is in training state.
 * \param train: whether the model is in training state.
 * \param bidirectional: if ``True``, becomes a bidirectional LSTM.
 * \param batch_first: if ``True``, the input and output tensors are provided
 *   as `(batch, seq, feature)` instead of `(seq, batch, feature)`. Note that
 *   this does not apply to hidden or cell states.
 * \param scale: the calibration scale of the output, as a double.
 * \param zp: the calibration zero point of the output, as an int64_t.
 * \param dtype: the calibration data type of the output.
 * \return: tuple of output tensors:
 *   output[0]: INT8 tensor of shape :math:`(L, N, D * H_{out})` when
 *   ``batch_first=False`` or :math:`(N, L, D * H_{out})` when
 *   ``batch_first=True``, containing the output features `(h_t)` from the
 *   last layer of the RNN, for each `t`.
 *   output[1]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N,
 *   H_{out})` containing the final hidden state for each element in the batch.
 *   output[2]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N,
 *   H_{out})` containing the final cell state for each element in the batch.
 * where:
 * .. math::
 *   \begin{aligned}
 *     N ={} & \text{batch size} \\
 *     L ={} & \text{sequence length} \\
 *     D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
 *     H_{in} ={} & \text{input\_size} \\
 *     H_{out} ={} & \text{hidden\_size}
 *   \end{aligned}
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor> quantized_lstm(
const at::Tensor& input,
c10::List<at::Tensor> hx,
c10::List<at::Tensor> weights,
bool has_biases,
int64_t num_layers,
double dropout_p,
bool train,
bool bidirectional,
bool batch_first,
double scale,
int64_t zp,
int64_t dtype);

} // namespace cpu
} // namespace torch_ipex
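A hypothetical call sketch for the declaration above (everything below is illustrative: the sizes, scales, and zero points are made up, the per-channel weight layout is an assumption, and passing dtype as the integer value of the quantized scalar type is likewise assumed):

#include <tuple>
#include <ATen/ATen.h>

void quantized_lstm_example() {
  int64_t T = 4, N = 2, I = 8, H = 8;        // seq len, batch, input, hidden
  at::Tensor qx = at::quantize_per_tensor(
      at::rand({T, N, I}), /*scale=*/0.05, /*zero_point=*/128, at::kQUInt8);
  c10::List<at::Tensor> hx({at::zeros({1, N, H}), at::zeros({1, N, H})});
  // Per-channel INT8 weights (4*H gate rows) plus FP32 biases.
  at::Tensor scales = at::full({4 * H}, 0.02, at::kDouble);
  at::Tensor zps = at::zeros({4 * H}, at::kLong);
  at::Tensor w_ih = at::quantize_per_channel(
      at::randn({4 * H, I}), scales, zps, /*axis=*/0, at::kQInt8);
  at::Tensor w_hh = at::quantize_per_channel(
      at::randn({4 * H, H}), scales, zps, /*axis=*/0, at::kQInt8);
  c10::List<at::Tensor> weights(
      {w_ih, w_hh, at::zeros({4 * H}), at::zeros({4 * H})});
  auto out = torch_ipex::cpu::quantized_lstm(
      qx, hx, weights, /*has_biases=*/true, /*num_layers=*/1,
      /*dropout_p=*/0.0, /*train=*/false, /*bidirectional=*/false,
      /*batch_first=*/false, /*scale=*/0.05, /*zp=*/128,
      /*dtype=*/static_cast<int64_t>(at::kQUInt8));
  at::Tensor q_output = std::get<0>(out);    // INT8, same qparams as input
}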
72 changes: 72 additions & 0 deletions intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp
@@ -453,6 +453,78 @@ void replaceInteractionWithQInteraction(std::shared_ptr<Graph>& graph) {
}
}

void replaceLstmWithQLstm(std::shared_ptr<Graph>& graph) {
std::vector<std::string> patterns;
std::vector<std::string> replacements;

for (auto* n : graph->block()->nodes()) {
if (n->kind() == aten::lstm) {
std::string weight_pattern = "";
std::vector<std::string> ListConstruct;
std::vector<std::string> header;

size_t id = 0;
auto weights_ListConstructNode = n->input(2)->node();

bool maybe_quantized_lstm = std::any_of(
weights_ListConstructNode->inputs().begin(),
weights_ListConstructNode->inputs().end(),
[](auto& v) {
return v->node()->kind() == Symbol::aten("dequantize");
});

if (!maybe_quantized_lstm)
return;

for (auto input : weights_ListConstructNode->inputs()) {
if (input->node()->kind() == Symbol::aten("dequantize")) {
std::string dequant = "%dq_out_" + std::to_string(id) +
" : Tensor = aten::dequantize(" + "%dq_in_" + std::to_string(id) +
")";
weight_pattern.append(dequant);

header.push_back("%dq_in_" + std::to_string(id));
ListConstruct.push_back("%dq_out_" + std::to_string(id));
} else {
header.push_back("%bias_in_" + std::to_string(id));
ListConstruct.push_back("%bias_in_" + std::to_string(id));
}
++id;
}

std::string complete_header =
"graph(%quantized_input, %h, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_fist, %scale, %zp, %dtype," +
c10::Join(", ", header) + R"(
): )";
std::string complete_LC = "%weights = prim::ListConstruct(" +
c10::Join(", ", ListConstruct) + ")";

std::string QLstmPattern = complete_header + R"(
%input : Tensor = aten::dequantize(%quantized_input) )" +
weight_pattern + complete_LC + R"(
%output, %hy, %cy = aten::lstm(%input, %h, %weights, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_first)
%quantized_output = aten::quantize_per_tensor(%output, %scale, %zp, %dtype)
return (%quantized_output, %hy, %cy) )";

std::string QLstmReplacement = complete_header + R"(
%quantized_weights : Tensor[] = prim::ListConstruct( )" +
c10::Join(", ", header) + R"(
)
%quantized_output, %hy, %cy = ipex::quantized_lstm(%quantized_input, %h, %quantized_weights, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_first, %scale, %zp, %dtype)
return (%quantized_output, %hy, %cy) )";

patterns.push_back(QLstmPattern);
replacements.push_back(QLstmReplacement);
}
}

SubgraphRewriter rewriter;
for (size_t i = 0; i < patterns.size(); i++) {
rewriter.RegisterRewritePattern(patterns[i], replacements[i]);
rewriter.runOnGraph(graph);
}
}
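To make the generated strings concrete: for a hypothetical single-layer, unidirectional LSTM with biases (two dequantized weights followed by two FP32 biases in the ListConstruct), the registered pattern reads roughly:

graph(%quantized_input, %h, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_first, %scale, %zp, %dtype, %dq_in_0, %dq_in_1, %bias_in_2, %bias_in_3):
  %input : Tensor = aten::dequantize(%quantized_input)
  %dq_out_0 : Tensor = aten::dequantize(%dq_in_0)
  %dq_out_1 : Tensor = aten::dequantize(%dq_in_1)
  %weights = prim::ListConstruct(%dq_out_0, %dq_out_1, %bias_in_2, %bias_in_3)
  %output, %hy, %cy = aten::lstm(%input, %h, %weights, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_first)
  %quantized_output = aten::quantize_per_tensor(%output, %scale, %zp, %dtype)
  return (%quantized_output, %hy, %cy)

and the replacement (same graph inputs) becomes:

  %quantized_weights : Tensor[] = prim::ListConstruct(%dq_in_0, %dq_in_1, %bias_in_2, %bias_in_3)
  %quantized_output, %hy, %cy = ipex::quantized_lstm(%quantized_input, %h, %quantized_weights, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_first, %scale, %zp, %dtype)
  return (%quantized_output, %hy, %cy)

so the dequantize/quantize pair around aten::lstm collapses into a single INT8 op.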

void fuseBmmAdd(std::shared_ptr<Graph>& graph) {
std::array<std::string, 2> add_operators = {"add", "add_"};

@@ -36,6 +36,7 @@ void replaceAtenBatchNormWithIpexBatchNorm(std::shared_ptr<Graph>& graph);
void replaceAtenLayerNormWithIpexLayerNorm(std::shared_ptr<Graph>& graph);
void replaceEmbeddingBagWithQEmbeddingBag(std::shared_ptr<Graph>& graph);
void replaceInteractionWithQInteraction(std::shared_ptr<Graph>& graph);
void replaceLstmWithQLstm(std::shared_ptr<Graph>& graph);

void replaceFrozenIPEXConvWithAtenConv(std::shared_ptr<Graph>& graph);
void insertPrePackedConvOp(std::shared_ptr<Graph>& graph);
@@ -16,6 +16,7 @@
#include "csrc/jit/cpu/kernels/MaxPool2D.h"
#include "csrc/jit/cpu/kernels/Mha.h"
#include "csrc/jit/cpu/kernels/OpContext.h"
#include "csrc/jit/cpu/kernels/RNN.h"
#include "csrc/jit/cpu/kernels/Shuffle.h"
#include "csrc/jit/cpu/kernels/Softmax.h"

@@ -767,6 +768,33 @@ RegisterOperators op({
},
aliasAnalysisFromSchema()),

Operator(
"ipex::quantized_lstm(Tensor quantized_input, Tensor[] hx, Tensor [] quantized_weights, bool has_biases, int num_layers, float dropout_p, bool train, bool bidirectional, bool batch_first, float scale, int zp, int dtype) -> (Tensor, Tensor, Tensor)",
[](const Node* node) -> Operation {
return [](Stack* stack) {
auto result = quantized_lstm(
(std::move(peek(stack, 0, 12))).toTensor(),
(std::move(peek(stack, 1, 12))).toTensorList(),
(std::move(peek(stack, 2, 12))).toTensorList(),
(std::move(peek(stack, 3, 12))).toBool(),
(std::move(peek(stack, 4, 12))).toInt(),
(std::move(peek(stack, 5, 12))).toDouble(),
(std::move(peek(stack, 6, 12))).toBool(),
(std::move(peek(stack, 7, 12))).toBool(),
(std::move(peek(stack, 8, 12))).toBool(),
(std::move(peek(stack, 9, 12))).toDouble(),
(std::move(peek(stack, 10, 12))).toInt(),
(std::move(peek(stack, 11, 12))).toInt());
drop(stack, 12);

pack(stack, std::move(std::get<0>(result)));
pack(stack, std::move(std::get<1>(result)));
pack(stack, std::move(std::get<2>(result)));
return 0;
};
},
aliasAnalysisFromSchema()),

Operator(
"ipex::shuffle_2d("
" Tensor input,"
3 changes: 2 additions & 1 deletion intel_extension_for_pytorch/csrc/quantization/AutoCast.cpp
@@ -474,7 +474,8 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
auto w = at::cat({_params[i], _params[i + 1]}, 1);
weights.push_back(w);
}
calibrate({input}, weights, {output}, "lstm", op_id, OP_TYPE_DEFAULT);
// oneDNN LSTM: input and output share the same scale and zero_point
calibrate({input}, weights, {input}, "lstm", op_id, OP_TYPE_DEFAULT);
return std::make_tuple(output, hy, cy);
}
params p = get_params(op_id);
39 changes: 39 additions & 0 deletions intel_extension_for_pytorch/csrc/quantization/utils/utils.h
@@ -0,0 +1,39 @@
#pragma once

#include <ATen/ATen.h>

namespace torch_ipex {
namespace int8 {
namespace utils {

inline std::tuple<double, int64_t> get_mkldnn_input_scale_zp(
const at::Tensor& input) {
TORCH_CHECK(
input.qscheme() == c10::QScheme::PER_TENSOR_AFFINE,
"should use per_tensor_affine quantization for input of LSTM");

double scale = input.q_scale();

// PyTorch scale: (max - min) / (qmax - qmin)
// oneDNN scale: (qmax - qmin) / (max - min)
double mkldnn_scale = 1. / scale;

int64_t zp = input.q_zero_point();
return std::make_tuple(mkldnn_scale, zp);
}

inline at::Tensor get_weight_scale_tensor(const at::Tensor& weight) {
TORCH_CHECK(
weight.qscheme() == c10::QScheme::PER_CHANNEL_AFFINE,
"should use per_channel_affine quantization for weight of LSTM");
at::Tensor weight_scales_tensor = weight.q_per_channel_scales();
TORCH_CHECK(
weight_scales_tensor.dim() == 1,
"expect weight_scales tensor to be 1d, got dim = ",
weight_scales_tensor.dim());
return weight_scales_tensor;
}

} // namespace utils
} // namespace int8
} // namespace torch_ipex
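A small worked example of the scale convention above (values hypothetical): PyTorch stores the dequantization scale, while oneDNN expects its reciprocal, the quantization scale.

#include <ATen/ATen.h>

void scale_zp_example() {
  at::Tensor qx = at::quantize_per_tensor(
      at::rand({2, 3}), /*scale=*/0.05, /*zero_point=*/128, at::kQUInt8);
  auto [mkldnn_scale, zp] =
      torch_ipex::int8::utils::get_mkldnn_input_scale_zp(qx);
  // mkldnn_scale ≈ 20.0 (i.e. 1.0 / 0.05), zp == 128
}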