Skip to content

Commit

Permalink
Merge pull request #69 from continue-revolution/ControlNet
Browse files Browse the repository at this point in the history
Support API for single image
  • Loading branch information
continue-revolution authored Apr 28, 2023
2 parents dc59b3e + d939fe1 commit d205fb8
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 128 deletions.
58 changes: 9 additions & 49 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,18 @@ This extension aim for connecting [AUTOMATIC1111 Stable Diffusion WebUI](https:/
- Image layout generation (single image + batch process)
- *Image masking with categories (single image + batch process)
- *Inpaint not masked for ControlNet inpainting on txt2img panel
- `2023/04/29`: [Feature] API has been completely refactored. You can access all features for **single image process** through API. API documentation has been moved to [wiki](https://github.com/continue-revolution/sd-webui-segment-anything/wiki/API).

This extension has been significantly refactored on `2023/04/24`. If you wish to revert to older version, please `git checkout 724b4db`.

## TODO

- [ ] Color selection for mask region and unmask region
- [ ] Batch ControlNet inpainting
- [ ] Only upload mask (Add content to image)
- [ ] "Masked content"
- [ ] Test EditAnything

## FAQ

Thanks for suggestions from [github issues](https://github.com/continue-revolution/sd-webui-segment-anything/issues), [reddit](https://www.reddit.com/r/StableDiffusion/comments/12hkdy8/sd_webui_segment_everything/) and [bilibili](https://www.bilibili.com/video/BV1Tg4y1u73r/) to make this extension better.
Expand Down Expand Up @@ -170,55 +179,6 @@ Mask by Category batch demo
| --- | --- | --- | --- |
| ![1NHa6Wc](https://user-images.githubusercontent.com/63914308/234085498-70ca1d4c-cc5a-44d4-adb2-366630e5ce24.png) | ![1NHa6Wc_0_output](https://user-images.githubusercontent.com/63914308/234085495-0bfc4114-3e81-4ace-81d6-0f0f3186df25.png) | ![1NHa6Wc_0_mask](https://user-images.githubusercontent.com/63914308/234085491-8976f46c-2617-47ee-968e-0a9dd479c63a.png) | ![1NHa6Wc_0_blend](https://user-images.githubusercontent.com/63914308/234085503-7e041373-39cd-4f20-8696-986be517f188.png)

## API Support

### API Usage

We have added an API endpoint to allow for automated workflows.

The API utilizes both Segment Anything and GroundingDINO to return masks of all instances of whatever object is specified in the text prompt.

This is an extension of the existing [Stable Diffusion Web UI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API).

There are 2 endpoints exposed
- GET `/sam-webui/heartbeat`
- POST `/sam-webui/image-mask`

The heartbeat endpoint can be used to ensure that the API is up.

The image-mask endpoint accepts a payload that includes your base64-encoded image.

Below is an example of how to interface with the API using requests.

### API Example

```
import base64
import requests
from PIL import Image
from io import BytesIO
url = "http://127.0.0.1:7860/sam-webui/image-mask"
def image_to_base64(img_path: str) -> str:
with open(img_path, "rb") as img_file:
img_base64 = base64.b64encode(img_file.read()).decode()
return img_base64
payload = {
"image": image_to_base64("IMAGE_FILE_PATH"),
"prompt": "TEXT PROMPT",
"box_threshold": 0.3,
"padding": 30 #Optional param to pad masks
}
res = requests.post(url, json=payload)
for dct in res.json():
image_data = base64.b64decode(dct['image'])
image = Image.open(BytesIO(image_data))
image.show()
```

## Contribute

Disclaimer: I have not thoroughly tested this extension, so there might be bugs. Bear with me while I'm fixing them :)
Expand Down
273 changes: 225 additions & 48 deletions scripts/api.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,238 @@
import os
from fastapi import FastAPI, Body
from io import BytesIO
import base64
from pydantic import BaseModel
from typing import Any, Optional
import asyncio
from typing import Any, Optional, List
import gradio as gr
import os
from scripts.sam import init_sam_model, dilate_mask, sam_predict, sam_model_list
from scripts.dino import dino_model_list
from PIL import Image, ImageChops
import base64
from PIL import Image
import numpy as np

from modules.api.api import encode_pil_to_base64, decode_base64_to_image
from scripts.sam import sam_predict, dino_predict, update_mask, cnet_seg, categorical_mask


def decode_to_pil(image):
if os.path.exists(image):
return Image.open(image)
elif type(image) is str:
return decode_base64_to_image(image)
elif type(image) is Image.Image:
return image
elif type(image) is np.ndarray:
return Image.fromarray(image)
else:
Exception("Not an image")


def encode_to_base64(image):
if type(image) is str:
return image
elif type(image) is Image.Image:
return encode_pil_to_base64(image).decode()
elif type(image) is np.ndarray:
pil = Image.fromarray(image)
return encode_pil_to_base64(pil).decode()
else:
Exception("Invalid type")


def sam_api(_: gr.Blocks, app: FastAPI):
@app.get("/sam-webui/heartbeat")
@app.get("/sam/heartbeat")
async def heartbeat():
return {
"msg": "Success!"
}

class SamPredictRequest(BaseModel):
sam_model_name: str = "sam_vit_h_4b8939.pth"
input_image: str
sam_positive_points: List[List[float]] = []
sam_negative_points: List[List[float]] = []
dino_enabled: bool = False
dino_model_name: Optional[str] = "GroundingDINO_SwinT_OGC (694MB)"
dino_text_prompt: Optional[str] = None
dino_box_threshold: Optional[float] = 0.3
dino_preview_checkbox: bool = False
dino_preview_boxes_selection: Optional[List[int]] = None

@app.post("/sam/sam-predict")
async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any:
print(f"SAM API /sam/sam-predict received request")
payload.input_image = decode_to_pil(payload.input_image).convert('RGBA')
sam_output_mask_gallery, sam_message = sam_predict(
payload.sam_model_name,
payload.input_image,
payload.sam_positive_points,
payload.sam_negative_points,
payload.dino_enabled,
payload.dino_model_name,
payload.dino_text_prompt,
payload.dino_box_threshold,
payload.dino_preview_checkbox,
payload.dino_preview_boxes_selection)
print(f"SAM API /sam/sam-predict finished with message: {sam_message}")
result = {
"msg": sam_message,
}
if len(sam_output_mask_gallery) == 9:
result["blended_images"] = list(map(encode_to_base64, sam_output_mask_gallery[:3]))
result["masks"] = list(map(encode_to_base64, sam_output_mask_gallery[3:6]))
result["masked_images"] = list(map(encode_to_base64, sam_output_mask_gallery[6:]))
return result

class DINOPredictRequest(BaseModel):
input_image: str
dino_model_name: str = "GroundingDINO_SwinT_OGC (694MB)"
text_prompt: str
box_threshold: float = 0.3

@app.post("/sam/dino-predict")
async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any:
print(f"SAM API /sam/dino-predict received request")
payload.input_image = decode_to_pil(payload.input_image)
dino_output_img, _, dino_msg = dino_predict(
payload.input_image,
payload.dino_model_name,
payload.text_prompt,
payload.box_threshold)
if "value" in dino_msg:
dino_msg = dino_msg["value"]
else:
dino_msg = "Done"
print(f"SAM API /sam/dino-predict finished with message: {dino_msg}")
return {
"msg": dino_msg,
"image_with_box": encode_to_base64(dino_output_img) if dino_output_img is not None else None,
}

class DilateMaskRequest(BaseModel):
input_image: str
mask: str
dilate_amount: int = 10

@app.post("/sam/dilate-mask")
async def api_dilate_mask(payload: DilateMaskRequest = Body(...)) -> Any:
print(f"SAM API /sam/dilate-mask received request")
payload.input_image = decode_to_pil(payload.input_image).convert("RGBA")
payload.mask = decode_to_pil(payload.mask)
dilate_result = list(map(encode_to_base64, update_mask(payload.mask, 0, payload.dilate_amount, payload.input_image)))
print(f"SAM API /sam/dilate-mask finished")
return {"blended_image": dilate_result[0], "mask": dilate_result[1], "masked_image": dilate_result[2]}


class AutoSAMConfig(BaseModel):
points_per_side: Optional[int] = 32
points_per_batch: int = 64
pred_iou_thresh: float = 0.88
stability_score_thresh: float = 0.95
stability_score_offset: float = 1.0
box_nms_thresh: float = 0.7
crop_n_layers: int = 0
crop_nms_thresh: float = 0.7
crop_overlap_ratio: float = 512 / 1500
crop_n_points_downscale_factor: int = 1
min_mask_region_area: int = 0

class ControlNetSegRequest(BaseModel):
sam_model_name: str = "sam_vit_h_4b8939.pth"
input_image: str
processor: str = "seg_ofade20k"
processor_res: int = 512
pixel_perfect: bool = False
resize_mode: Optional[int] = 1 # 0: just resize, 1: crop and resize, 2: resize and fill
target_W: Optional[int] = None
target_H: Optional[int] = None

@app.post("/sam/controlnet-seg")
async def api_controlnet_seg(payload: ControlNetSegRequest = Body(...),
autosam_conf: AutoSAMConfig = Body(...)) -> Any:
print(f"SAM API /sam/controlnet-seg received request")
payload.input_image = decode_to_pil(payload.input_image)
cnet_seg_img, cnet_seg_msg = cnet_seg(
payload.sam_model_name,
payload.input_image,
payload.processor,
payload.processor_res,
payload.pixel_perfect,
payload.resize_mode,
payload.target_W,
payload.target_H,
autosam_conf.points_per_side,
autosam_conf.points_per_batch,
autosam_conf.pred_iou_thresh,
autosam_conf.stability_score_thresh,
autosam_conf.stability_score_offset,
autosam_conf.box_nms_thresh,
autosam_conf.crop_n_layers,
autosam_conf.crop_nms_thresh,
autosam_conf.crop_overlap_ratio,
autosam_conf.crop_n_points_downscale_factor,
autosam_conf.min_mask_region_area)
cnet_seg_img = list(map(encode_to_base64, cnet_seg_img))
print(f"SAM API /sam/controlnet-seg finished with message {cnet_seg_msg}")
result = {
"msg": cnet_seg_msg,
}
if len(cnet_seg_img) == 3:
result["blended_images"] = cnet_seg_img[0]
result["random_seg"] = cnet_seg_img[1]
result["edit_anything_control"] = cnet_seg_img[2]
elif len(cnet_seg_img) == 4:
result["sem_presam"] = cnet_seg_img[0]
result["sem_postsam"] = cnet_seg_img[1]
result["blended_presam"] = cnet_seg_img[2]
result["blended_postsam"] = cnet_seg_img[3]
return result

class MaskRequest(BaseModel):
image: str #base64 string containing image
prompt: str
box_threshold: float
padding: Optional[int] = 0


def pil_image_to_base64(img: Image.Image) -> str:
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_base64 = base64.b64encode(buffered.getvalue()).decode()
return img_base64

@app.post("/sam-webui/image-mask")
async def process_image(payload: MaskRequest = Body(...)) -> Any:
sam_model_name = sam_model_list[0] if len(sam_model_list) > 0 else None
dino_model_name = dino_model_list[0] if len(dino_model_list) > 0 else None
# Decode the base64 image string
img_b64 = base64.b64decode(payload.image)
input_img = Image.open(BytesIO(img_b64))
#Run DINO and SAM inference to get masks back
masks = sam_predict(sam_model_name,
input_img,
[],
[],
True,
dino_model_name,
payload.prompt,
payload.box_threshold,
None,
None,
gui=False)[0]
if payload.padding:
masks = [dilate_mask(mask, payload.padding)[0] for mask in masks]
# Convert the final PIL image to a base64 string
response = [{"image": pil_image_to_base64(mask)} for mask in masks]

return response
class CategoryMaskRequest(BaseModel):
sam_model_name: str = "sam_vit_h_4b8939.pth"
processor: str = "seg_ofade20k"
processor_res: int = 512
pixel_perfect: bool = False
resize_mode: Optional[int] = 1
target_W: Optional[int] = None
target_H: Optional[int] = None
category: str
input_image: str

@app.post("/sam/category-mask")
async def api_category_mask(payload: CategoryMaskRequest = Body(...),
autosam_conf: AutoSAMConfig = Body(...)) -> Any:
print(f"SAM API /sam/category-mask received request")
payload.input_image = decode_to_pil(payload.input_image)
category_mask_img, category_mask_msg, resized_input_img = categorical_mask(
payload.sam_model_name,
payload.processor,
payload.processor_res,
payload.pixel_perfect,
payload.resize_mode,
payload.target_W,
payload.target_H,
payload.category,
payload.input_image,
autosam_conf.points_per_side,
autosam_conf.points_per_batch,
autosam_conf.pred_iou_thresh,
autosam_conf.stability_score_thresh,
autosam_conf.stability_score_offset,
autosam_conf.box_nms_thresh,
autosam_conf.crop_n_layers,
autosam_conf.crop_nms_thresh,
autosam_conf.crop_overlap_ratio,
autosam_conf.crop_n_points_downscale_factor,
autosam_conf.min_mask_region_area)
category_mask_img = list(map(encode_to_base64, category_mask_img))
print(f"SAM API /sam/category-mask finished with message {category_mask_msg}")
result = {
"msg": category_mask_msg,
}
if len(category_mask_img) == 3:
result["blended_image"] = category_mask_img[0]
result["mask"] = category_mask_img[1]
result["masked_image"] = category_mask_img[2]
if resized_input_img is not None:
result["resized_input"] = encode_to_base64(resized_input_img)
return result


try:
import modules.script_callbacks as script_callbacks
Expand Down
Loading

0 comments on commit d205fb8

Please sign in to comment.