fix(//cpp/ptq): Tracing model in eval mode wrecks accuracy in Libtorch
HACK: WYA tracing without being in eval mode and ignoring the warning,
will follow up with the PyTorch Team and test after script mode support
lands

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Apr 25, 2020
1 parent cd24f26 commit 54a24b3
Showing 2 changed files with 131 additions and 7 deletions.
127 changes: 126 additions & 1 deletion cpp/ptq/README.md
@@ -1,10 +1,125 @@
# ptq

## How to create your own PTQ application

Post Training Quantization (PTQ) is a technique that reduces the computational resources required for inference while still preserving the accuracy of your model, by mapping the traditional FP32 activation space to a reduced INT8 space. TensorRT uses a calibration step which executes your model with sample data from the target domain and tracks the activations in FP32 in order to calibrate a mapping to INT8 that minimizes the information loss between FP32 and INT8 inference.
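
To make the mapping concrete, here is a purely illustrative sketch of symmetric INT8 quantization, where a per-tensor scale is derived from the dynamic range observed during calibration. This is not TensorRT's implementation (TensorRT selects ranges with calibration algorithms such as entropy calibration), and the function names below are hypothetical.

```C++
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative symmetric quantization: map FP32 values into [-127, 127]
// using a per-tensor scale derived from a calibrated absolute-max range.
inline int8_t quantize(float x, float scale) {
    float q = std::round(x / scale);
    return static_cast<int8_t>(std::max(-127.0f, std::min(127.0f, q)));
}

inline float dequantize(int8_t q, float scale) {
    return static_cast<float>(q) * scale;
}

// The "calibration" here just records the maximum absolute value seen over
// sample data; real calibrators (e.g. entropy calibration) instead pick the
// range that minimizes information loss.
inline float calibrate_scale(const std::vector<float>& samples) {
    float max_abs = 0.0f;
    for (float v : samples) {
        max_abs = std::max(max_abs, std::fabs(v));
    }
    return max_abs / 127.0f;
}
```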

Users writing TensorRT applications are required to set up a calibrator class which provides sample data to TensorRT during calibration. With TRTorch we look to leverage existing infrastructure in PyTorch to make implementing calibrators easier.

LibTorch provides a `DataLoader` and `Dataset` API which streamlines preprocessing and batching of input data. TRTorch uses DataLoaders as the base of its generic calibrator implementation, so you can reuse or quickly implement a `torch::Dataset` for your target domain, wrap it in a DataLoader, and create an INT8 calibrator from it which you can provide to TRTorch to run INT8 calibration during compilation of your module.

### Code

Here is an example interface of a `torch::Dataset` class for CIFAR10:

```C++
//cpp/ptq/datasets/cifar10.h
#pragma once

#include "torch/data/datasets/base.h"
#include "torch/data/example.h"
#include "torch/types.h"

#include <cstddef>
#include <string>

namespace datasets {
// The CIFAR10 Dataset
class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
public:
    // The mode in which the dataset is loaded
    enum class Mode { kTrain, kTest };

    // Loads CIFAR10 from un-tarred file
    // Dataset can be found https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
    // Root path should be the directory that contains the content of tarball
    explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain);

    // Returns the pair at index in the dataset
    torch::data::Example<> get(size_t index) override;

    // The size of the dataset
    c10::optional<size_t> size() const override;

    // The mode the dataset is in
    bool is_train() const noexcept;

    // Returns all images stacked into a single tensor
    const torch::Tensor& images() const;

    // Returns all targets stacked into a single tensor
    const torch::Tensor& targets() const;

    // Trims the dataset to the first n pairs
    CIFAR10&& use_subset(int64_t new_size);

private:
    Mode mode_;
    torch::Tensor images_, targets_;
};
} // namespace datasets
```
This class's implementation reads from the binary distribution of the CIFAR10 dataset and builds two tensors which hold the images and labels.
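
The accessor methods build on the pre-loaded `images_` and `targets_` tensors. As a rough idea of what that can look like, here is a sketch consistent with the header above (not the actual source):

```C++
// Sketch only: accessors and use_subset implemented on top of the
// pre-built images_/targets_ tensors.
torch::data::Example<> CIFAR10::get(size_t index) {
    return {images_[index], targets_[index]};
}

c10::optional<size_t> CIFAR10::size() const {
    return images_.size(0);
}

bool CIFAR10::is_train() const noexcept {
    return mode_ == Mode::kTrain;
}

CIFAR10&& CIFAR10::use_subset(int64_t new_size) {
    // Keep only the first new_size image/label pairs
    images_ = images_.slice(0, 0, new_size);
    targets_ = targets_.slice(0, 0, new_size);
    return std::move(*this);
}
```
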
Next we select a subset of the dataset to use for calibration, since we don't need the full dataset and calibration takes time, then define the preprocessing to apply to the images and create a DataLoader which will batch the data:
```C++
auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
    .use_subset(320)
    .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
                                              {0.2023, 0.1994, 0.2010}))
    .map(torch::data::transforms::Stack<>());
auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset),
    torch::data::DataLoaderOptions().batch_size(32)
                                    .workers(2));
```
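
If you want to sanity check what the calibrator will consume, the DataLoader can be iterated directly. With the `Stack<>` transform each batch is a single `torch::data::Example<>` whose `data` and `target` members hold the stacked images and labels. This is an illustrative snippet, not part of the example application, and assumes `<iostream>` is included:

```C++
// Illustrative: peek at one batch from the calibration DataLoader
for (auto& batch : *calibration_dataloader) {
    std::cout << "data: " << batch.data.sizes()     // expect [32, 3, 32, 32]
              << " target: " << batch.target.sizes() << std::endl;
    break;
}
```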

Next we create a calibrator from the `calibration_dataloader` using the calibrator factory:

```C++
auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);

```

Here we also define a location to write a calibration cache file, which lets us reuse the calibration data without needing the dataset, as well as whether or not to use the cache file if it exists. There is also a `trtorch::ptq::make_int8_cache_calibrator` factory which creates a calibrator that uses the cache only, for cases where you may do engine building on a machine that has limited storage (i.e. no space for a dataset) or want a simpler deployment application.
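
For example, a deployment application that only ships the cache file could construct its calibrator roughly like this (a sketch, assuming the cache was produced by a previous calibration run):

```C++
// Build the calibrator purely from a previously generated cache file;
// no dataset or DataLoader is needed on the deployment machine.
auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);
```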

The calibrator factories create a calibrator that inherits from an `nvinfer1::IInt8Calibrator` virtual class (`nvinfer1::IInt8EntropyCalibrator2` by default) which defines the calibration algorithm used when calibrating. You can explicitly select the calibration algorithm like this:

```C++
// MinMax Calibrator is geared more towards NLP tasks
auto calibrator = trtorch::ptq::make_int8_calibrator<nvinfer1::IInt8MinMaxCalibrator>(std::move(calibration_dataloader), calibration_cache_file, true);
```

Then all that's required to set up the module for INT8 calibration is to set the following compile settings in the `trtorch::ExtraInfo` struct and compile the module:

```C++
std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
/// Configure settings for compilation
auto extra_info = trtorch::ExtraInfo({input_shape});
/// Set operating precision to INT8
extra_info.op_precision = torch::kI8;
/// Use the TensorRT Entropy Calibrator
extra_info.ptq_calibrator = calibrator;
/// Set a larger workspace (you may get better performance from doing so)
extra_info.workspace_size = 1 << 28;

auto trt_mod = trtorch::CompileGraph(mod, extra_info);
```
If you have an existing calibrator implementation for TensorRT, you may directly set the `ptq_calibrator` field with a pointer to your calibrator and it will work as well.

From here, not much changes in terms of how execution works. You are still able to use LibTorch as the sole interface for inference. Data should remain in FP32 precision when it's passed into `trt_mod.forward`.
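
As a minimal sketch (names and shapes are illustrative), inference with the quantized module uses the normal LibTorch interface:

```C++
// Inputs stay in FP32 and on the GPU; the compiled module is invoked
// through the usual torch::jit::Module::forward interface.
auto in = torch::randn({32, 3, 32, 32}, torch::kCUDA);
std::vector<torch::jit::IValue> inputs{in};
auto out = trt_mod.forward(inputs).toTensor();
```
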
## Running the Example Application

This is a short example application that shows how to use TRTorch to perform post-training quantization for a module.

## Prerequisites

1. Download the CIFAR10 dataset, binary version ([https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz](https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz))
2. Train a network on CIFAR10 (see `training/` for a VGG16 recipe)
3. Export the model to TorchScript
@@ -26,6 +141,16 @@ bazel build //cpp/ptq --compilation_mode=dbg
ptq <path-to-module> <path-to-cifar10>
```

## Example Output

```
Accuracy of JIT model on test set: 92.1%
Compiling and quantizing module
Accuracy of quantized model on test set: 91.0044%
Latency of JIT model FP32 (Batch Size 32): 1.73497ms
Latency of quantized model (Batch Size 32): 0.365737ms
```

## Citations

```
11 changes: 5 additions & 6 deletions cpp/ptq/training/vgg16/export_ckpt.py
@@ -18,7 +18,6 @@ def test(model, dataloader, crit):
     loss = 0.0
     class_probs = []
     class_preds = []
-    model.eval()
 
     with torch.no_grad():
         for data, labels in dataloader:
@@ -54,9 +53,12 @@ def test(model, dataloader, crit):
 weights = new_state_dict
 
 model.load_state_dict(weights)
-model.eval()
 
+# Setting eval here causes both JIT and TRT accuracy to tank in LibTorch will follow up with PyTorch Team
+#model.eval()
+
 jit_model = torch.jit.trace(model, torch.rand([32, 3, 32, 32]).to("cuda"))
+jit_model.eval()
 
 testing_dataset = datasets.CIFAR10(root='./data', train=False, download=True,
                                    transform=transforms.Compose([
@@ -68,10 +70,7 @@ def test(model, dataloader, crit):
                                                  shuffle=False, num_workers=2)
 
 crit = torch.nn.CrossEntropyLoss()
-test_loss, test_acc = test(model, testing_dataloader, crit)
-print("[PTH] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
 
-torch.jit.save(jit_model, "trained_vgg16.jit.pt")
-jit_model = torch.jit.load("trained_vgg16.jit.pt")
 test_loss, test_acc = test(jit_model, testing_dataloader, crit)
 print("[JIT] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
+torch.jit.save(jit_model, "trained_vgg16.jit.pt")
