SAM finetune and inference #34

Open: wants to merge 47 commits into base: master

Changes from all commits (47 commits)
- b6ab463 add SAM training and inference model (Aug 21, 2023)
- e6a9d5b add inference one image (Sep 15, 2023)
- 0e40651 decouple image , box and mask in resize and pad transform (Oct 8, 2023)
- 5c35b4f add lite inference (Oct 10, 2023)
- c52bb5d move sam examples to demo (Oct 13, 2023)
- a567520 move sam from research to official/cv (Oct 16, 2023)
- d784688 update comment for lite-inference (Oct 16, 2023)
- 54d3eab add O2 for eval.py (Oct 19, 2023)
- 76868c2 fix bug of maskmiou calculation (Oct 19, 2023)
- e179f42 add sa-1b dataset (Nov 16, 2023)
- 364c856 add model arts config (Nov 23, 2023)
- ce5e3bc hack work_root (Nov 23, 2023)
- 1202897 inference -> O0 (Nov 24, 2023)
- 49e7667 add ckpt download link (Nov 29, 2023)
- d387717 workaround of layernorm2d (Nov 29, 2023)
- 1e7a945 set drop_over_flow true (Dec 5, 2023)
- 0ade063 add text prompt inference(blip2) (Dec 7, 2023)
- 8ac6eba add text prompt train(blip2) (Dec 9, 2023)
- 90d453b fix bug of sample with no large mask (Dec 9, 2023)
- 1370b64 add filter_text_encoder button in saving ckpt (Dec 11, 2023)
- a733213 adapt model_wrapper to support text and box input (Dec 11, 2023)
- 15e4b6d add cost time log (Dec 12, 2023)
- 110f102 add text-prompt readme (Dec 12, 2023)
- c605461 add clip text_encoder (Dec 13, 2023)
- 1dfd983 1.fixbug of unfreezing text_proj; 2.add dimension projection for blip2 (Dec 14, 2023)
- 02d987e change default O2 to O0 to prevent failure of clip training (Dec 14, 2023)
- 4d58984 update readme and text-inference script (Dec 15, 2023)
- 9e384f7 add cloud training config (Dec 15, 2023)
- 717aa48 update cloud training dataset config (Dec 15, 2023)
- 8c5ae68 update text-inference (Dec 16, 2023)
- 8eb16fa add batch inference demos (Dec 20, 2023)
- b22ac77 add point finetune (Jan 2, 2024)
- 85ec292 add point finetune: remove duplicate graph complie (Jan 3, 2024)
- e23ea03 add point finetune: add iter training to config (Jan 3, 2024)
- 367f0cf add point finetune: add point inference (Jan 4, 2024)
- 0a23f9f add point finetune: fix get new point bug, add point transform (Jan 4, 2024)
- abc616e add point finetune: to amp_level O0 (Jan 4, 2024)
- 88ec8b2 add point finetune: add loss scale (Jan 4, 2024)
- 82d5c91 add point finetune: change from grad sum to mean (Jan 5, 2024)
- 53a779d add point finetune: fix bug of unscale order (Jan 8, 2024)
- a1929ce add point finetune: lr schedule from dynamic to pre-computed (Jan 8, 2024)
- 81cc03d add point finetune: prevent duplicate graph compile of grad reducer (Jan 8, 2024)
- 51c481e add point finetune: update config of lr_scheduler (Jan 9, 2024)
- 37588ed add point finetune: fix bug of valid_boxes (Jan 9, 2024)
- 84f697b add point finetune: fix bug of all_finite for distributed training (Jan 10, 2024)
- eb04165 update readme and inference demo (Jan 11, 2024)
- 0fee2a9 update readme of blip2 and clip (Jan 11, 2024)
official/cv/segment-anything/README.md (266 additions, 0 deletions)
# Segment Anything

The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.

## Installation

The code requires `python>=3.7` and supports the Ascend platform. Some important pre-dependencies are:
1. mindspore: please follow the instructions [here](https://www.mindspore.cn/install) to install the mindspore dependencies.
2. mindformers: please follow the instructions [here](https://gitee.com/mindspore/mindformers) to install mindformers from source.

Clone the repository locally and install with

```shell
git clone https://github.com/Mark-ZhouWX/models.git
cd models/official/cv/segment-anything
pip install -r requirements.txt
```

## Finetune

Finetuning is a popular method for adapting a large pretrained model to specific downstream tasks. Currently, finetuning with box, point, and text prompts is supported.

*Note that finetuning of SAM is not open-sourced in the [official PyTorch implementation](https://github.com/facebookresearch/segment-anything).
In this repository, finetuning is an experimental feature and still under improvement.*

### Finetune with box-prompt
Bounding boxes are used as the prompt input to predict masks.
Besides fine-tuning on the COCO2017 dataset, which contains commonly seen objects and lies in a distribution similar to the original [training dataset](https://segment-anything.com/dataset/index.html) of SAM, we have run further experiments on a medical imaging segmentation dataset, [FLARE22](https://flare22.grand-challenge.org/Dataset/). The results show that the finetuning method in this repository is effective.
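For orientation, the box prompt is an `(x_min, y_min, x_max, y_max)` array in original image coordinates (see `box_inference.py` in this PR), whereas COCO annotations store boxes as `[x, y, width, height]`. A conversion along the following lines is therefore needed; this is an illustrative sketch only, and the actual dataset transform in this repository may handle it differently:

```python
import numpy as np

def coco_xywh_to_xyxy(boxes_xywh: np.ndarray) -> np.ndarray:
    """Convert COCO-style [x, y, w, h] boxes to [x_min, y_min, x_max, y_max]."""
    boxes = np.asarray(boxes_xywh, dtype=np.float32)
    xyxy = boxes.copy()
    xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]  # x_max = x + w
    xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]  # y_max = y + h
    return xyxy

# Example: the COCO box [425, 600, 275, 275] becomes the prompt [425, 600, 700, 875].
```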

The table below shows the mask quality before and after finetuning.


| pretrained_model | dataset | epochs | mIOU | ckpt |
|:----------------:| -------- |:-------------:|------|--------------------------------------------------------------------------------------------------------------|
| sam-vit-b | COCO2017 | 0 (zero-shot) | 74.5 | |
| sam-vit-b | COCO2017 | 20 | 80.2 | [link](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_box_finetune_coco-a9b75828.ckpt) |
| sam-vit-b | FLARE22 | 0 (zero-shot) | 78.6 | |
| sam-vit-b | FLARE22 | 10 | 87.4 | [link](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_box_finetune_flare-ace06cc2.ckpt) |

A machine with **32 GB of Ascend memory** is required for box-prompt finetuning.

For standalone finetuning on the COCO dataset, run:
```shell
python train.py -c configs/coco_box_finetune.yaml -o amp_level=O2
```

For distributed finetuning on the COCO dataset, run:
```shell
mpirun --allow-run-as-root -n 8 python train.py -c configs/coco_box_finetune.yaml -o amp_level=O2
```
The fine-tuned model will be saved under the `work_root` specified in `configs/coco_box_finetune.yaml`. To evaluate the model, run:
```shell
python eval.py -c configs/coco_box_finetune.yaml -o amp_level=O2 network.model.checkpoint=your/path/to/ckpt
```
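The evaluation reports the mask mIoU shown in the table above. Conceptually, the per-mask metric is the intersection-over-union between the predicted and ground-truth binary masks, as in this minimal NumPy sketch (the exact implementation in `eval.py` may differ):

```python
import numpy as np

def mask_iou(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-6) -> float:
    """IoU between two binary masks of the same spatial size."""
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return float(inter / (union + eps))

# mIoU over a dataset is simply the mean of the per-mask IoUs:
# ious = [mask_iou(p, g) for p, g in zip(pred_masks, gt_masks)]
# miou = sum(ious) / len(ious)
```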
For fast single-image inference, run:
```shell
python inference.py --amp_level=O2 --checkpoint=your/path/to/ckpt
```

The original FLARE22 dataset contains images in 3D format, with ground truth labelled as instance segmentation IDs. Run

```shell
python scripts/preprocess_CT_MR_dataset.py
```

to preprocess it into 2D RGB images and binary masks.
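Conceptually, the preprocessing slices each 3D volume, replicates the intensity channel to RGB, and splits the instance-ID label map into one binary mask per organ. The NumPy sketch below illustrates that idea only; the bundled `preprocess_CT_MR_dataset.py` handles format reading, intensity windowing, and file layout, and may differ in detail:

```python
import numpy as np

def volume_to_slices(volume: np.ndarray, labels: np.ndarray):
    """volume: (D, H, W) float intensities; labels: (D, H, W) int instance IDs.

    Yields (rgb_image, {instance_id: binary_mask}) for each axial slice.
    """
    for z in range(volume.shape[0]):
        img = volume[z]
        # Normalize intensities to 0-255 and replicate to 3 channels.
        img = (img - img.min()) / (np.ptp(img) + 1e-6)
        rgb = np.repeat((img * 255).astype(np.uint8)[..., None], 3, axis=-1)
        # One binary mask per non-background instance ID in this slice.
        masks = {int(i): (labels[z] == i).astype(np.uint8)
                 for i in np.unique(labels[z]) if i != 0}
        yield rgb, masks
```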

The remaining steps are similar to the COCO finetuning described above.

Here are examples of segmentation results predicted by the box-prompt fine-tuned SAM:

<div align="center">
<img src="images/coco_bear.jpg" height="350" />

<img src="images/flare_organ.jpg" height="350" />
</div>

<p align="center">
<em> COCO2017 image example</em>


<em> FLARE22 image example </em>
</p>

### Finetune with point-prompt
A point, together with the mask output from the previous step, is used as the prompt input to predict a mask.
We follow the iterative interactive training schedule described in the official SAM paper. First, a foreground point is sampled uniformly from the ground truth mask. After making a prediction,
subsequent points are selected uniformly from the error region between the previous mask prediction and the ground truth mask. Each new point is labelled foreground or background depending on whether the error region is a false negative or a false positive.
The mask prediction from the previous iteration is used as an additional prompt. To encourage the model to benefit from the supplied mask, several more iterations are added in which no additional points are sampled.
The total number of iterations and the positions where the mask-only iterations are inserted are configurable.
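A minimal NumPy sketch of this point-sampling rule is shown below; it is illustrative only, and the actual training pipeline here additionally handles batching, the mask prompt, and the configurable mask-only iterations:

```python
import numpy as np

def sample_next_point(prev_pred: np.ndarray, gt_mask: np.ndarray, rng=np.random):
    """Sample one correction point from the error region.

    prev_pred, gt_mask: (H, W) binary masks.
    Returns ((y, x), label) with label 1 for foreground (false negative)
    and 0 for background (false positive), or None if there is no error.
    """
    false_neg = np.logical_and(gt_mask == 1, prev_pred == 0)
    false_pos = np.logical_and(gt_mask == 0, prev_pred == 1)
    error = np.logical_or(false_neg, false_pos)
    ys, xs = np.nonzero(error)
    if len(ys) == 0:  # perfect prediction, nothing to correct
        return None
    i = rng.randint(len(ys))
    y, x = ys[i], xs[i]
    label = 1 if false_neg[y, x] else 0
    return (int(y), int(x)), label

# The first point is sampled uniformly from the ground-truth mask, which is the
# same as correcting an all-zero "previous" prediction:
# first_point = sample_next_point(np.zeros_like(gt_mask), gt_mask)
```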

Since the original training dataset (SA-1B) consists mostly of common objects, we use a medical imaging segmentation dataset, [FLARE22](https://flare22.grand-challenge.org/Dataset/) (preprocess the raw dataset as described in the previous section), for the finetuning experiment.
Note that SAM exhibits strong zero-shot ability, so for most downstream datasets the finetuning process may mainly learn the labelling bias.

For standalone finetuning on the FLARE22 dataset, run:
```shell
python train.py -c configs/sa1b_point_finetune.yaml
```

For distributed finetuning on the FLARE22 dataset, run:
```shell
mpirun --allow-run-as-root -n 4 python train.py -c configs/sa1b_point_finetune.yaml
```

The fine-tuned model will be saved under the `work_root` specified in `configs/sa1b_point_finetune.yaml`. For fast single-image inference, run:

```shell
python point_inference.py --checkpoint=your/path/to/ckpt
```

Below is an experimental result batch-prompted with 5 points; the model is trained at the `vit_b` scale. The checkpoint can be downloaded [here](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_point_finetune_flare-898ae8f6.ckpt).
<div align="center">
<img alt="img.png" src="images/tumor2_5point.png" width="600"/>
</div>

Explore more interesting applications, such as iterative positive- and negative-point prompting, described in the Demo section below.

### Finetune with text-prompt
*Note again that text-to-mask finetuning is exploratory and not robust, and the official PyTorch code has not been released yet.*


The training procedure described in the official SAM paper is quite interesting in that it does not require new text annotations. Specifically, for each manually collected mask with an area larger than 100² pixels we extract the CLIP image embedding. Then, during training, we prompt SAM
with the extracted CLIP image embeddings as the text-prompt input. At inference time we run text through CLIP's text encoder and give the resulting text embedding as a prompt to SAM.

The key that makes this training procedure work is that CLIP's image embeddings are trained to align with its text embeddings.
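In code terms, the trick is simply a matter of which CLIP encoder feeds SAM's prompt path. Below is a hedged sketch; the callables are illustrative placeholders, not the actual APIs of this repository:

```python
import numpy as np

# Illustrative stand-ins for CLIP (or BLIP2) and SAM; only the training/inference
# asymmetry of the prompt source is the point here.
def clip_image_encoder(image_crop: np.ndarray) -> np.ndarray: ...
def clip_text_encoder(text: str) -> np.ndarray: ...
def sam_predict(image: np.ndarray, text_embedding: np.ndarray) -> np.ndarray: ...

def training_prompt(image: np.ndarray, gt_mask: np.ndarray) -> np.ndarray:
    # Training: prompt SAM with the CLIP *image* embedding of the region around
    # the ground-truth mask, so no text annotation is required.
    ys, xs = np.nonzero(gt_mask)
    crop = image[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    return clip_image_encoder(crop)

def inference_prompt(text: str) -> np.ndarray:
    # Inference: a real text prompt goes through CLIP's *text* encoder, whose
    # output space CLIP pre-training has aligned with the image embeddings.
    return clip_text_encoder(text)
```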

This repository provides an implementation of text-to-mask finetuning that follows the model structure and training procedure described in the official SAM paper, and introduces a stronger multimodal encoder, BLIP2, in addition to CLIP.

A machine with **64 GB of Ascend memory** is required for text-prompt finetuning.

First download the SA-1B dataset and put it under `${project_root}/datasets/sa-1b`.

For standalone finetuning on the SA-1B dataset with BLIP2 (CLIP is similar), run:
```shell
python train.py -c configs/sa1b_text_finetune_blip2.yaml
```
The BLIP2 checkpoint and the BERT `vocabulary.txt` will be automatically downloaded to `./checkpoint_download/`.

For distributed finetuning, run:
```shell
mpirun --allow-run-as-root -n 8 python train.py -c configs/sa1b_text_finetune_blip2.yaml
```
The fine-tuned model will be saved under the `work_root` specified in `configs/sa1b_text_finetune.yaml`. For fast single-image inference, run:

```shell
python text_inference.py --checkpoint=your/path/to/ckpt --text-prompt your_prompt
```

Below are some zero-shot experimental results prompted with `floor` and `buildings`. The checkpoint fine-tuned with BLIP2 can be downloaded [here](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_text_finetune_sa1b_10k-972de39e.ckpt). _Note that the model is trained with limited data and the smallest SAM backbone, `vit_b`._

<div align="center">
<img src="images/dengta-floor.png" height="350" />

<img src="images/dengta-buildings.png" height="350" />
</div>

<p align="center">
<em> prompt: floor</em>


<em> prompt: buildings </em>
</p>

Try more prompts, such as `sky` or `trees`.

## Demo

First download the weights ([sam_vit_b](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_b-35e4849c.ckpt), [sam_vit_l](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_l-1b460f38.ckpt), [sam_vit_h](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_h-c72f8ba1.ckpt)) and put them under the `${project_root}/models` directory.
There are two recommended ways to use SAM.

### Using SAM with prompts

#### Predict one object at a time

1. points

SAM predicts object masks given prompts that indicate the desired object. If a point prompt is given, three plausible masks are generated.

```shell
python demo/inference_with_promts.py --prompt-type point --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<p float="left">
<img src=images/truck_mask1.png width="400"/><img src=images/truck_mask2.png width="400"/><img src=images/truck_mask3.png width="400"/>
</p>

If a prompt with two points is given, one plausible mask is generated instead of three, because there is less ambiguity than with a single-point prompt.
The green and red stars denote positive and negative points, respectively.

<div align="center">
<img alt="img.png" src="images/truck_two_point.png" width="600"/>
</div>

2. one box

If a box prompt is given, one plausible mask is generated.

```shell
python demo/inference_with_promts.py --prompt-type box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<div align="center">
<img alt="img.png" width="600" src="images/truck_box.png"/>
</div>

3. one box and one point

If a prompt with both a box and a point is given, one plausible mask is generated.

```shell
python demo/inference_with_promts.py --prompt-type point_box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<div align="center">
<img alt="img.png" width="600" src="images/truck_point_box.png"/>
</div>

#### Predict multiple objects at a time in batch mode

1. batch point

```shell
python demo/inference_with_promts.py --prompt-type batch_point --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<div align="center">
<img alt="img.png" src="images/truck_batch_point.png" width="600"/>
</div>

2. batch box

```shell
python demo/inference_with_promts.py --prompt-type batch_box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<div align="center">
<img alt="img.png" width="600" src="images/truck_batch_box.png"/>
</div>

3. batch box and point

```shell
python demo/inference_with_promts.py --prompt-type batch_point_box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<div align="center">
<img alt="img.png" width="600" src="images/truck_batch_point_box.png"/>
</div>

See `python demo/inference_with_promts.py --help` to explore more custom settings.

### Using SAM with Automatic Mask Generation (AMG)

Since SAM can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image. AMG works by sampling single-point input prompts in a grid over the image, from each of which SAM can predict multiple masks. Then, masks are filtered for quality and deduplicated using non-maximal suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or postprocessing masks to remove small disconnected regions and holes.
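Schematically, AMG is a loop of grid-point sampling, quality filtering, and mask-level non-maximal suppression. The sketch below illustrates that pipeline only: `predict_point` is a placeholder for a single-point SAM call, the thresholds are illustrative, and the real implementation behind `demo/inference_with_amg.py` adds the options described above, such as multiple crops and mask postprocessing:

```python
import numpy as np

def mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return float(inter) / float(union + 1e-6)

def automatic_masks(predict_point, h, w, points_per_side=32,
                    score_thresh=0.88, nms_iou=0.7):
    """predict_point(y, x) -> (binary_mask, quality_score) stands in for a
    single-point SAM prediction; thresholds are illustrative defaults."""
    # 1. Sample single-point prompts on a regular grid over the image.
    ys = np.linspace(0, h - 1, points_per_side).astype(int)
    xs = np.linspace(0, w - 1, points_per_side).astype(int)
    candidates = []
    for y in ys:
        for x in xs:
            mask, score = predict_point(y, x)
            # 2. Keep only sufficiently confident masks.
            if score >= score_thresh:
                candidates.append((mask, score))
    # 3. Greedy NMS on mask IoU to remove duplicate masks.
    candidates.sort(key=lambda t: t[1], reverse=True)
    kept = []
    for mask, score in candidates:
        if all(mask_iou(mask, k) < nms_iou for k, _ in kept):
            kept.append((mask, score))
    return kept
```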

```shell
python demo/inference_with_amg.py --model-type vit_h
```

<div align="center">
<img src="images/dengta.jpg" height="350" />

<img src="images/dengta-amg-vith.png" height="350" />
</div>

See `python demo/inference_with_amg.py --help` to explore more custom settings.
official/cv/segment-anything/box_inference.py (89 additions, 0 deletions)
import argparse

import cv2
import matplotlib.pyplot as plt
import numpy as np

import mindspore as ms

from segment_anything.build_sam import sam_model_registry
from segment_anything.dataset.transform import ImageNorm, ImageResizeAndPad, TransformPipeline
from segment_anything.utils.utils import Timer
from segment_anything.utils.visualize import show_box, show_mask


def infer(args):
    ms.context.set_context(mode=args.mode, device_target=args.device)

    # Step1: data preparation
    with Timer('preprocess'):
        transform_list = [
            ImageResizeAndPad(target_size=1024, apply_mask=False),
            ImageNorm(),
        ]
        transform_pipeline = TransformPipeline(transform_list)

        image_path = args.image_path
        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        # One box prompt in (x_min, y_min, x_max, y_max) image coordinates.
        boxes_np = np.array([[425, 600, 700, 875]])

        transformed = transform_pipeline(dict(image=image_np, boxes=boxes_np))
        image, boxes, origin_hw = transformed['image'], transformed['boxes'], transformed['origin_hw']
        # batch_size for speed test
        # image = ms.Tensor(np.expand_dims(image, 0).repeat(8, axis=0))  # b, 3, 1024, 1024
        # boxes = ms.Tensor(np.expand_dims(boxes, 0).repeat(8, axis=0))  # b, n, 4
        image = ms.Tensor(image).unsqueeze(0)  # b, 3, 1024, 1024
        boxes = ms.Tensor(boxes).unsqueeze(0)  # b, n, 4

    # Step2: inference
    with Timer('model inference'):
        with Timer('load weight and build net'):
            network = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
            ms.amp.auto_mixed_precision(network=network, amp_level=args.amp_level)
        mask_logits = network(image, boxes=boxes)[0]  # (1, 1, 1024, 1024)

    # In graph mode the first forward pass above also compiles the graph, so a
    # second pass is timed to reflect steady-state inference latency.
    with Timer('Second time inference'):
        mask_logits = network(image, boxes=boxes)[0]  # (1, 1, 1024, 1024)

    # Step3: post-process
    with Timer('post-process'):
        mask_logits = mask_logits.asnumpy()[0, 0] > 0.0
        mask_logits = mask_logits.astype(np.uint8)
        # Crop away the padding, then resize the mask back to the original image size.
        final_mask = cv2.resize(mask_logits[:origin_hw[2], :origin_hw[3]], (origin_hw[1], origin_hw[0]),
                                interpolation=cv2.INTER_CUBIC)

    # Step4: visualize
    plt.imshow(image_np)
    show_box(boxes_np[0], plt.gca())
    show_mask(final_mask, plt.gca())
    plt.savefig(args.image_path + '_infer.jpg')
    plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Runs inference on one image")
    parser.add_argument("--image_path", type=str, default='./images/truck.jpg', help="Path to an input image.")
    parser.add_argument(
        "--model-type",
        type=str,
        default='vit_b',
        help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default='./models/sam_vit_b-35e4849c.ckpt',
        help="Path to the SAM checkpoint to load.",
    )

    parser.add_argument("--device", type=str, default="Ascend", help="The device to run generation on.")
    parser.add_argument("--amp_level", type=str, default="O0", help="Auto mixed precision level: O0 or O2.")
    parser.add_argument("--mode", type=int, default=0, help="MindSpore context mode: 0 for graph, 1 for pynative.")

    args = parser.parse_args()
    print(args)
    infer(args)