-
Notifications
You must be signed in to change notification settings - Fork 244
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
enable int8 LSTM on latest cpu-device (#692)
* int8 lstm graph rewrite * integrate oneDNN int8 lstm * use scale and zp of input to be those of output for lstm * add UT for int8 lstm * rename var to maybe_quantized_lstm * add assertion for input scalar type * only get input scalar type once * add doxygen spec for pack_qlstm_weight * add doxygen spec for quantized_lstm * use inline utils function to get scale and zero point of input and weight
- Loading branch information
1 parent
c28e621
commit 2bf8dba
Showing
12 changed files
with
597 additions
and
67 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
#pragma once

#include <ATen/Tensor.h>

#include <c10/core/Scalar.h>
#include <torch/csrc/jit/runtime/custom_operator.h>

#include "csrc/cpu/ideep/ideep.hpp"

namespace torch_ipex {
namespace cpu {

//! function: quantized_lstm
/*!
 * Run a quantized LSTM over an INT8 input sequence with INT8 weights and
 * FP32 initial hidden/cell states, producing an INT8 output sequence along
 * with FP32 final hidden and cell states.
 *
 * \param input: INT8 tensor of shape :math:`(L, N, H_{in})` when
 *     ``batch_first=False`` or :math:`(N, L, H_{in})` when
 *     ``batch_first=True``, containing the features of the input sequence.
 * \param hx: list holding the FP32 initial states:
 *     hx[0]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})`
 *     with the initial hidden state for the batch.
 *     hx[1]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})`
 *     with the initial cell state for the batch.
 * \param weights: list of INT8 weights and FP32 biases.
 * \param has_biases: if ``False``, the layer does not use bias weights
 *     `b_ih` and `b_hh`.
 * \param num_layers: number of stacked LSTM layers.
 * \param dropout_p: if non-zero, applies `Dropout` on the outputs of each
 *     RNN layer except the last, with probability :attr:`dropout` while the
 *     model is in training state.
 * \param train: whether the model is in training state.
 * \param bidirectional: if ``True``, the LSTM is bidirectional.
 * \param batch_first: if ``True``, input and output tensors are laid out as
 *     `(batch, seq, feature)` instead of `(seq, batch, feature)`. Note this
 *     does not apply to the hidden or cell states.
 * \param scale: calibration scale of the output (double).
 * \param zp: calibration zero point of the output (int64_t).
 * \param dtype: calibration data type of the output.
 * \return: tuple of output tensors:
 *     output[0]: INT8 tensor of shape :math:`(L, N, D * H_{out})` when
 *     ``batch_first=False`` or :math:`(N, L, D * H_{out})` when
 *     ``batch_first=True``, containing the output features `(h_t)` from the
 *     last layer of the RNN, for each `t`.
 *     output[1]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N,
 *     H_{out})` containing the final hidden state for each batch element.
 *     output[2]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N,
 *     H_{out})` containing the final cell state for each batch element.
 * where:
 * .. math::
 *     \begin{aligned}
 *         N ={} & \text{batch size} \\
 *         L ={} & \text{sequence length} \\
 *         D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
 *         H_{in} ={} & \text{input\_size} \\
 *         H_{out} ={} & \text{hidden\_size}
 *     \end{aligned}
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor> quantized_lstm(
    const at::Tensor& input,
    c10::List<at::Tensor> hx,
    c10::List<at::Tensor> weights,
    bool has_biases,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional,
    bool batch_first,
    double scale,
    int64_t zp,
    int64_t dtype);

} // namespace cpu
} // namespace torch_ipex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
intel_extension_for_pytorch/csrc/quantization/utils/utils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#pragma once

#include <ATen/ATen.h>

namespace torch_ipex {
namespace int8 {
namespace utils {

/// Translate the quantization parameters of a per-tensor-affine quantized
/// input tensor into the (scale, zero point) convention oneDNN expects.
///
/// PyTorch stores scale as (max - min) / (qmax - qmin), while oneDNN uses
/// the reciprocal, (qmax - qmin) / (max - min), so the scale is inverted
/// here. The zero point is passed through unchanged.
inline std::tuple<double, int64_t> get_mkldnn_input_scale_zp(
    const at::Tensor& input) {
  TORCH_CHECK(
      input.qscheme() == c10::QScheme::PER_TENSOR_AFFINE,
      "should use per_tensor_affine quantization for input of LSTM");

  const double torch_scale = input.q_scale();

  // oneDNN scale is the reciprocal of the PyTorch scale.
  const double mkldnn_scale = 1. / torch_scale;

  const int64_t zero_point = input.q_zero_point();
  return std::make_tuple(mkldnn_scale, zero_point);
}

/// Fetch the per-channel scale tensor of a per-channel-affine quantized
/// weight tensor, verifying the quantization scheme and that the scales
/// form a 1-D tensor.
inline at::Tensor get_weight_scale_tensor(const at::Tensor& weight) {
  TORCH_CHECK(
      weight.qscheme() == c10::QScheme::PER_CHANNEL_AFFINE,
      "should use per_channel_affine quantization for weight of LSTM");
  at::Tensor scales = weight.q_per_channel_scales();
  TORCH_CHECK(
      scales.dim() == 1,
      "expect weight_scales tensor to be 1d, got dim = ",
      scales.dim());
  return scales;
}

} // namespace utils
} // namespace int8
} // namespace torch_ipex
Oops, something went wrong.