Commit: Add PTQ example for PyTorch CV - Segment Anything Model (#1464)
Showing 7 changed files with 714 additions and 0 deletions.
examples/pytorch/image_recognition/segment_anything/README.md (68 additions, 0 deletions)
Step-by-Step
============
This document provides step-by-step instructions for applying post-training quantization (PTQ) to the Segment Anything Model (SAM) using the VOC dataset.

# Prerequisite
## Environment
```shell
# install dependencies
pip install -r ./requirements.txt
# install the SAM package and download the pre-trained weights
pip install git+https://github.com/facebookresearch/segment-anything.git
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```
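`sam_vit_b_01ec64.pth` is the ViT-B checkpoint of SAM, the variant this example quantizes; Meta also publishes larger ViT-L and ViT-H checkpoints at the same location.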

# PTQ
This is a PTQ example for the Segment Anything Model (SAM) on the VOC dataset.

## 1. Prepare VOC dataset
```shell
python download_dataset.py
```
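The script downloads and extracts VOC2012 under `./voc_dataset`; the directories this example relies on are laid out as follows:

```
voc_dataset/
└── VOCdevkit/
    └── VOC2012/
        ├── Annotations/               # XML annotations with object bounding boxes
        ├── ImageSets/Segmentation/    # train.txt / val.txt split lists
        ├── JPEGImages/                # input images
        └── SegmentationObject/        # instance segmentation masks
```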

## 2. Start PTQ
```shell
bash run_quant.sh --voc_dataset_location=./voc_dataset/VOCdevkit/VOC2012/ --pretrained_weight_location=./sam_vit_b_01ec64.pth
```
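Tuning saves the quantized model under `./saved_results`; the benchmarking step below picks it up via `--tuned_checkpoint`.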

## 3. Benchmarking
```shell
bash run_benchmark.sh --tuned_checkpoint=./saved_results --voc_dataset_location=./voc_dataset/VOCdevkit/VOC2012/ --int8=True --mode=performance
```
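`--mode=performance` runs a throughput benchmark of the INT8 model loaded from `./saved_results`; presumably the FP32 baseline can be benchmarked by dropping `--int8=True`.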

# Result
|          | Baseline (FP32) | INT8   |
| -------- | --------------- | ------ |
| Accuracy | 0.7939          | 0.7849 |

# Saving and Loading Model

* Saving model:
After tuning with Neural Compressor, `quantization.fit` returns a `neural_compressor.model` object:

```python
from neural_compressor import PostTrainingQuantConfig
from neural_compressor import quantization
conf = PostTrainingQuantConfig()
q_model = quantization.fit(model,
                           conf,
                           calib_dataloader=val_loader,
                           eval_func=eval_func)
```

Here, `q_model` is a Neural Compressor model object, so it provides a `save` API:

```python
q_model.save("Path_to_save_quantized_model")
```

* Loading model:

```python
from neural_compressor.utils.pytorch import load
quantized_model = load(os.path.abspath(os.path.expanduser(args.tuned_checkpoint)),
                       model,
                       dataloader=val_loader)
```
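The `model`, `val_loader`, and `eval_func` in the snippets above come from `main.py`. Neural Compressor calls `eval_func(model)` and expects a single accuracy score back; a minimal hypothetical sketch (both helper functions below are illustrative, not the actual implementation):

```python
# Hypothetical sketch of eval_func; the real implementation lives in main.py.
# Neural Compressor calls eval_func(model) and expects a single float back.
def eval_func(model):
    scores = []
    for data, label in val_loader:
        pred = run_sam_inference(model, data)         # illustrative helper: forward pass + mask postprocessing
        scores.append(compute_accuracy(pred, label))  # illustrative helper: compare mask against ground truth
    return sum(scores) / len(scores)                  # mean score drives accuracy-aware tuning
```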

Please refer to `main.py` for the complete implementation.
examples/pytorch/image_recognition/segment_anything/download_dataset.py (7 additions, 0 deletions)
import torchvision

print("Downloading VOC dataset")
torchvision.datasets.VOCDetection(root='./voc_dataset', year='2012', image_set='trainval', download=True)
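Note: `torchvision.datasets.VOCDetection` downloads the full VOC2012 trainval archive (roughly 2 GB), so the first run needs some time and disk space.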
examples/pytorch/image_recognition/segment_anything/inc_dataset_loader.py (173 additions, 0 deletions)
import torchvision
import torch
from PIL import Image

import numpy as np
import os
import xml.etree.ElementTree as ET
import torch.nn.functional as F
from segment_anything.utils.transforms import ResizeLongestSide


# Pad an image tensor on the bottom/right to a square, as SAM's preprocessing does
def pad_image(x: torch.Tensor, square_length=1024) -> torch.Tensor:
    # x is C, H, W
    h, w = x.shape[-2:]
    padh = square_length - h
    padw = square_length - w
    x = F.pad(x, (0, padw, 0, padh))
    return x


# Custom dataset
class INC_SAMVOC2012Dataset(object):
    def __init__(self, voc_root, type):
        self.voc_root = voc_root
        self.num_of_data = -1
        self.dataset = {}  # Each item will be ["filename", "class_name", [4 bounding box coordinates]]
        self.resizelongestside = ResizeLongestSide(target_length=1024)
        pixel_mean = [123.675, 116.28, 103.53]
        pixel_std = [58.395, 57.12, 57.375]
        self.pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
        self.pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)

        # Read through all the samples and build a dictionary:
        # the key is an integer idx; the value holds filename, class name and bounding box
        annotation_dir = os.path.join(voc_root, "Annotations")
        files = os.listdir(annotation_dir)
        files = [f for f in files if os.path.isfile(os.path.join(annotation_dir, f))]  # filter out directories
        annotation_files = [os.path.join(annotation_dir, x) for x in files]

        # Get the name list of the segmentation files
        segmentation_dir = os.path.join(voc_root, "SegmentationObject")
        files = os.listdir(segmentation_dir)
        segmentation_files = [f for f in files if os.path.isfile(os.path.join(segmentation_dir, f))]  # filter out directories

        # Select data based on the split type (train/val)
        train_val_dir = os.path.join(voc_root, 'ImageSets/Segmentation/')
        if type == 'train':
            txt_file_name = 'train.txt'
        elif type == 'val':
            txt_file_name = 'val.txt'
        else:
            raise ValueError("Type of dataset should be 'train' or 'val'")

        with open(os.path.join(train_val_dir, txt_file_name), 'r') as f:
            permitted_files = [row.rstrip('\n') for row in f]

        for file in annotation_files:
            file_name = os.path.splitext(os.path.basename(file))[0]

            if file_name not in permitted_files:
                continue  # skip files outside the selected split

            if file_name + '.png' in segmentation_files:  # keep only annotations with a matching segmentation file
                tree = ET.parse(file)
                root = tree.getroot()
                for child in root:
                    if child.tag == 'object':
                        details = [file_name]
                        for node in child:
                            if node.tag == 'name':
                                object_name = node.text
                            if node.tag == 'bndbox':
                                for coordinates in node:
                                    if coordinates.tag == 'xmax':
                                        xmax = int(coordinates.text)
                                    if coordinates.tag == 'xmin':
                                        xmin = int(coordinates.text)
                                    if coordinates.tag == 'ymax':
                                        ymax = int(coordinates.text)
                                    if coordinates.tag == 'ymin':
                                        ymin = int(coordinates.text)
                                boundary = [xmin, ymin, xmax, ymax]
                        details.append(object_name)
                        details.append(boundary)
                        self.num_of_data += 1
                        self.dataset[self.num_of_data] = details

    def __len__(self):
        # num_of_data holds the last zero-based index, so the count is one more
        return self.num_of_data + 1

    # Preprocess the segmentation mask. Output semantic information for a single object only.
    def preprocess_segmentation(self, filename, bounding_box, pad=True):

        # Read the instance segmentation mask
        segment_mask = Image.open(os.path.join(self.voc_root, 'SegmentationObject', filename + '.png'))
        segment_mask_np = torchvision.transforms.functional.pil_to_tensor(segment_mask)

        # Crop the segmentation to the bounding box
        xmin, ymin = int(bounding_box[0]), int(bounding_box[1])
        xmax, ymax = int(bounding_box[2]), int(bounding_box[3])
        cropped_mask = segment_mask.crop((xmin, ymin, xmax, ymax))
        cropped_mask_np = torchvision.transforms.functional.pil_to_tensor(cropped_mask)

        # Count the majority element inside the box
        bincount = np.bincount(cropped_mask_np.reshape(-1))
        bincount[0] = 0  # ignore the black (background) pixels
        if bincount.shape[0] >= 256:
            bincount[255] = 0  # ignore the white (boundary) pixels
        majority_element = bincount.argmax()

        # Binarize the full mask against the majority element
        segment_mask_np[np.where((segment_mask_np != 0) & (segment_mask_np != majority_element))] = 0
        segment_mask_np[segment_mask_np == majority_element] = 1

        # Pad the segment mask to 1024x1024 (so samples can be batched in a dataloader)
        if pad:
            segment_mask_np = pad_image(segment_mask_np)

        return segment_mask_np

    # Preprocess the image into the format SAM expects
    def preprocess_image(self, img):
        # Roughly equivalent to set_image() in SAM's predictor.py
        img = np.array(img)
        input_image = self.resizelongestside.apply_image(img)
        input_image_torch = torch.as_tensor(input_image, device='cpu')
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
        input_image_torch = (input_image_torch - self.pixel_mean) / self.pixel_std  # normalize
        original_size = img.shape[:2]
        input_size = tuple(input_image_torch.shape[-2:])

        return pad_image(input_image_torch), original_size, input_size

    def __getitem__(self, idx):
        data = self.dataset[idx]
        filename, classname = data[0], data[1]
        bounding_box = data[2]

        # Ground-truth mask: preprocessing without padding
        mask_gt = self.preprocess_segmentation(filename, bounding_box, pad=False)

        # Read and preprocess the image
        image, original_size, input_size = self.preprocess_image(Image.open(os.path.join(self.voc_root, 'JPEGImages', filename + '.jpg')))
        prompt = bounding_box  # bounding box prompt - input_boxes x1, y1, x2, y2
        training_data = {}
        training_data["image"] = image
        training_data["original_size"] = original_size
        training_data["input_size"] = input_size
        training_data["ground_truth_mask"] = mask_gt
        training_data["prompt"] = prompt
        return (training_data, mask_gt)  # (data, label)

class INC_SAMVOC2012Dataloader:
    def __init__(self, batch_size, **kwargs):
        self.batch_size = batch_size
        self.dataset = []
        ds = INC_SAMVOC2012Dataset(kwargs['voc_root'], kwargs['type'])
        # Materialize all (input_data, label) pairs into self.dataset up front
        for i in range(len(ds)):
            self.dataset.append(ds[i])

    def __iter__(self):
        # Yields one (input_data, label) pair at a time; batch_size is stored
        # as an attribute but batching is not applied by this iterator.
        for input_data, label in self.dataset:
            yield input_data, label
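
For illustration, a minimal sketch of how this dataloader might be constructed and consumed, assuming the VOC path from the README (the loop below is hypothetical and not part of the commit):

```python
# Hypothetical usage sketch of the calibration/validation dataloader
from inc_dataset_loader import INC_SAMVOC2012Dataloader

val_loader = INC_SAMVOC2012Dataloader(
    batch_size=1,
    voc_root='./voc_dataset/VOCdevkit/VOC2012/',
    type='val',
)

for data, label in val_loader:
    # "image" is a normalized, padded 3x1024x1024 tensor;
    # label is the unpadded binary ground-truth mask
    print(data["image"].shape, label.shape)
    break
```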