Skip to content

Commit

Permalink
refactor API
Browse files Browse the repository at this point in the history
  • Loading branch information
continue-revolution committed Apr 28, 2023
1 parent fd80c25 commit 91d23c5
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 61 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ This extension aim for connecting [AUTOMATIC1111 Stable Diffusion WebUI](https:/
- Image layout generation (single image + batch process)
- *Image masking with categories (single image + batch process)
- *Inpaint not masked for ControlNet inpainting on txt2img panel
- `2023/04/29`: [Feature] The API has been completely refactored. You can access all features for **single image processing** through the API. The API documentation has been moved to the [wiki](https://github.com/continue-revolution/sd-webui-segment-anything/wiki/API).

This extension has been significantly refactored on `2023/04/24`. If you wish to revert to older version, please `git checkout 724b4db`.

Expand All @@ -23,7 +24,7 @@ This extension has been significantly refactored on `2023/04/24`. If you wish to
- [ ] Color selection for mask region and unmask region
- [ ] Batch ControlNet inpainting
- [ ] Only upload mask (Add content to image)
- [ ] What does "Masked content" mean?
- [ ] "Masked content"
- [ ] Test EditAnything

## FAQ
Expand Down Expand Up @@ -186,7 +187,7 @@ We have added an API endpoint to allow for automated workflows.

The API utilizes both Segment Anything and GroundingDINO to return masks of all instances of whatever object is specified in the text prompt.

This is an extension of the existing [Stable Diffusion Web UI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API).
This is an extension of the existing [Stable Diffusion WebUI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API).

There are 2 endpoints exposed
- GET `/sam-webui/heartbeat`
Expand Down
273 changes: 225 additions & 48 deletions scripts/api.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,238 @@
import os
from fastapi import FastAPI, Body
from io import BytesIO
import base64
from pydantic import BaseModel
from typing import Any, Optional
import asyncio
from typing import Any, Optional, List
import gradio as gr
import os
from scripts.sam import init_sam_model, dilate_mask, sam_predict, sam_model_list
from scripts.dino import dino_model_list
from PIL import Image, ImageChops
import base64
from PIL import Image
import numpy as np

from modules.api.api import encode_pil_to_base64, decode_base64_to_image
from scripts.sam import sam_predict, dino_predict, update_mask, cnet_seg, categorical_mask


def decode_to_pil(image):
    """Convert *image* into a PIL image.

    Args:
        image: one of
            - ``str`` naming an existing file: opened with ``Image.open``;
            - any other ``str``: decoded with ``decode_base64_to_image``;
            - ``PIL.Image.Image`` (or a subclass): returned unchanged;
            - ``numpy.ndarray``: converted with ``Image.fromarray``.

    Returns:
        The corresponding PIL image.

    Raises:
        TypeError: if *image* is none of the supported types.
    """
    # Type-check before touching the filesystem: os.path.exists() only swallows
    # OSError/ValueError, so handing it a PIL image or an ndarray raised
    # TypeError and the non-string branches could never be reached.
    if isinstance(image, str):
        # NOTE(review): a caller-supplied string is used as a server-side path
        # here; confirm this is acceptable for values arriving through the API.
        if os.path.exists(image):
            return Image.open(image)
        return decode_base64_to_image(image)
    # isinstance (not an exact type match) so that Image subclasses, e.g. the
    # objects returned by Image.open, are accepted too.
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    # The exception used to be constructed but never raised, which made the
    # function silently return None.
    raise TypeError("Not an image")


def encode_to_base64(image):
    """Return *image* as a base64 string.

    Args:
        image: one of
            - ``str``: returned unchanged (it is not validated as base64);
            - ``numpy.ndarray``: converted with ``Image.fromarray`` first;
            - ``PIL.Image.Image`` (or a subclass): encoded with
              ``encode_pil_to_base64`` and decoded from bytes to ``str``.

    Returns:
        The encoded image as ``str``.

    Raises:
        TypeError: if *image* is none of the supported types.
    """
    if isinstance(image, str):
        return image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # isinstance (not an exact type match) so that Image subclasses are accepted.
    if isinstance(image, Image.Image):
        return encode_pil_to_base64(image).decode()
    # The exception used to be constructed but never raised, which made the
    # function silently return None for unsupported inputs.
    raise TypeError("Invalid type")


def sam_api(_: gr.Blocks, app: FastAPI):
@app.get("/sam-webui/heartbeat")
@app.get("/sam/heartbeat")
async def heartbeat():
    # Liveness probe: always answers with a fixed success message.
    # (A comment rather than a docstring, so the generated API docs stay as they were.)
    status = {"msg": "Success!"}
    return status

# Request body for POST /sam/sam-predict.
# api_sam_predict passes these fields to sam_predict positionally, in exactly
# this declaration order, so the order here must match sam_predict's signature.
class SamPredictRequest(BaseModel):
sam_model_name: str = "sam_vit_h_4b8939.pth"
# Handed to decode_to_pil (existing file path or base64 string), then converted to RGBA.
input_image: str
# presumably [x, y] coordinates of prompt points — TODO confirm against sam_predict
sam_positive_points: List[List[float]] = []
sam_negative_points: List[List[float]] = []
# The dino_* fields are forwarded unchanged; their semantics are defined by sam_predict.
dino_enabled: bool = False
dino_model_name: Optional[str] = "GroundingDINO_SwinT_OGC (694MB)"
dino_text_prompt: Optional[str] = None
dino_box_threshold: Optional[float] = 0.3
dino_preview_checkbox: bool = False
dino_preview_boxes_selection: Optional[List[int]] = None

@app.post("/sam/sam-predict")
async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any:
    # Runs sam_predict on the decoded input image and returns its status message,
    # plus base64-encoded images when the gallery holds exactly 9 entries.
    print(f"SAM API /sam/sam-predict received request")
    payload.input_image = decode_to_pil(payload.input_image).convert('RGBA')
    predict_args = (
        payload.sam_model_name,
        payload.input_image,
        payload.sam_positive_points,
        payload.sam_negative_points,
        payload.dino_enabled,
        payload.dino_model_name,
        payload.dino_text_prompt,
        payload.dino_box_threshold,
        payload.dino_preview_checkbox,
        payload.dino_preview_boxes_selection,
    )
    gallery, message = sam_predict(*predict_args)
    print(f"SAM API /sam/sam-predict finished with message: {message}")
    result = {"msg": message}
    if len(gallery) == 9:
        # The gallery is split into three consecutive groups of three images.
        for key, start in (("blended_images", 0), ("masks", 3), ("masked_images", 6)):
            result[key] = [encode_to_base64(item) for item in gallery[start:start + 3]]
    return result

# Request body for POST /sam/dino-predict.
# api_dino_predict forwards all four fields, in this order, to dino_predict.
class DINOPredictRequest(BaseModel):
# Handed to decode_to_pil (existing file path or base64 string).
input_image: str
dino_model_name: str = "GroundingDINO_SwinT_OGC (694MB)"
text_prompt: str
box_threshold: float = 0.3

@app.post("/sam/dino-predict")
async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any:
    # Runs dino_predict on the decoded image and returns its message together
    # with the base64-encoded output image (None when no image was produced).
    print(f"SAM API /sam/dino-predict received request")
    payload.input_image = decode_to_pil(payload.input_image)
    boxed_image, _, status = dino_predict(
        payload.input_image,
        payload.dino_model_name,
        payload.text_prompt,
        payload.box_threshold,
    )
    # dino_predict's third result carries its text under "value" when present;
    # otherwise report a plain "Done".
    status = status["value"] if "value" in status else "Done"
    print(f"SAM API /sam/dino-predict finished with message: {status}")
    encoded_image = None if boxed_image is None else encode_to_base64(boxed_image)
    return {"msg": status, "image_with_box": encoded_image}

# Request body for POST /sam/dilate-mask.
class DilateMaskRequest(BaseModel):
# Handed to decode_to_pil, then converted to RGBA by api_dilate_mask.
input_image: str
# Handed to decode_to_pil and passed to update_mask as its first argument.
mask: str
# Passed to update_mask as the dilation amount.
dilate_amount: int = 10

@app.post("/sam/dilate-mask")
async def api_dilate_mask(payload: DilateMaskRequest = Body(...)) -> Any:
    # Dilates the supplied mask via update_mask and returns the first three
    # outputs base64-encoded.
    print(f"SAM API /sam/dilate-mask received request")
    payload.input_image = decode_to_pil(payload.input_image).convert("RGBA")
    payload.mask = decode_to_pil(payload.mask)
    outputs = update_mask(payload.mask, 0, payload.dilate_amount, payload.input_image)
    encoded = [encode_to_base64(item) for item in outputs]
    print(f"SAM API /sam/dilate-mask finished")
    return {
        "blended_image": encoded[0],
        "mask": encoded[1],
        "masked_image": encoded[2],
    }


# Second request-body model of /sam/controlnet-seg and /sam/category-mask.
# Both endpoints pass these fields, in exactly this declaration order, as the
# trailing positional arguments of cnet_seg / categorical_mask.
# NOTE(review): names and defaults look like segment-anything's automatic mask
# generator options — confirm before documenting individual fields.
class AutoSAMConfig(BaseModel):
points_per_side: Optional[int] = 32
points_per_batch: int = 64
pred_iou_thresh: float = 0.88
stability_score_thresh: float = 0.95
stability_score_offset: float = 1.0
box_nms_thresh: float = 0.7
crop_n_layers: int = 0
crop_nms_thresh: float = 0.7
crop_overlap_ratio: float = 512 / 1500
crop_n_points_downscale_factor: int = 1
min_mask_region_area: int = 0

# Request body for POST /sam/controlnet-seg.
# api_controlnet_seg forwards the fields, in this order, as the leading
# positional arguments of cnet_seg.
class ControlNetSegRequest(BaseModel):
sam_model_name: str = "sam_vit_h_4b8939.pth"
# Handed to decode_to_pil (existing file path or base64 string).
input_image: str
processor: str = "seg_ofade20k"
processor_res: int = 512
pixel_perfect: bool = False
resize_mode: Optional[int] = 1 # 0: just resize, 1: crop and resize, 2: resize and fill
target_W: Optional[int] = None
target_H: Optional[int] = None

@app.post("/sam/controlnet-seg")
async def api_controlnet_seg(payload: ControlNetSegRequest = Body(...),
                             autosam_conf: AutoSAMConfig = Body(...)) -> Any:
    # Runs cnet_seg on the decoded image and labels the encoded outputs
    # according to how many images came back (3 or 4); any other count yields
    # only the message.
    print(f"SAM API /sam/controlnet-seg received request")
    payload.input_image = decode_to_pil(payload.input_image)
    # Order matters: these are passed positionally after the payload fields.
    autosam_fields = (
        "points_per_side",
        "points_per_batch",
        "pred_iou_thresh",
        "stability_score_thresh",
        "stability_score_offset",
        "box_nms_thresh",
        "crop_n_layers",
        "crop_nms_thresh",
        "crop_overlap_ratio",
        "crop_n_points_downscale_factor",
        "min_mask_region_area",
    )
    autosam_args = [getattr(autosam_conf, field) for field in autosam_fields]
    seg_images, seg_message = cnet_seg(
        payload.sam_model_name,
        payload.input_image,
        payload.processor,
        payload.processor_res,
        payload.pixel_perfect,
        payload.resize_mode,
        payload.target_W,
        payload.target_H,
        *autosam_args)
    encoded = [encode_to_base64(item) for item in seg_images]
    print(f"SAM API /sam/controlnet-seg finished with message {seg_message}")
    keys_by_count = {
        3: ("blended_images", "random_seg", "edit_anything_control"),
        4: ("sem_presam", "sem_postsam", "blended_presam", "blended_postsam"),
    }
    result = {"msg": seg_message}
    result.update(zip(keys_by_count.get(len(encoded), ()), encoded))
    return result

# Request body for POST /sam-webui/image-mask (see process_image).
# NOTE(review): this model and its endpoint appear to be the pre-refactor side
# of the diff — confirm whether they still exist in the committed file.
class MaskRequest(BaseModel):
image: str # base64 string containing the image; process_image b64-decodes it directly
# Passed to sam_predict as the text prompt.
prompt: str
box_threshold: float
# When truthy, every returned mask is passed through dilate_mask with this amount.
padding: Optional[int] = 0


def pil_image_to_base64(img: Image.Image) -> str:
    """Serialize *img* as JPEG in memory and return the bytes base64-encoded as ``str``."""
    jpeg_stream = BytesIO()
    img.save(jpeg_stream, format="JPEG")
    jpeg_bytes = jpeg_stream.getvalue()
    return base64.b64encode(jpeg_bytes).decode()

@app.post("/sam-webui/image-mask")
async def process_image(payload: MaskRequest = Body(...)) -> Any:
    # Runs sam_predict with DINO enabled on the posted image, using the first
    # listed SAM and DINO models, and returns one {"image": <base64>} entry per mask.
    # First available model of each kind, or None when the list is empty.
    sam_model_name = sam_model_list[0] if sam_model_list else None
    dino_model_name = dino_model_list[0] if dino_model_list else None

    # The request carries the picture as a base64 string.
    raw_image = base64.b64decode(payload.image)
    input_img = Image.open(BytesIO(raw_image))

    # Only the first element of sam_predict's result (the masks) is used.
    prediction = sam_predict(
        sam_model_name, input_img, [], [], True, dino_model_name,
        payload.prompt, payload.box_threshold, None, None, gui=False)
    masks = prediction[0]

    if payload.padding:
        # dilate_mask returns a pair; its first element is kept.
        masks = [dilate_mask(mask, payload.padding)[0] for mask in masks]

    response = []
    for mask in masks:
        response.append({"image": pil_image_to_base64(mask)})
    return response
# Request body for POST /sam/category-mask.
# api_category_mask forwards the fields, in this order, as the leading
# positional arguments of categorical_mask.
class CategoryMaskRequest(BaseModel):
sam_model_name: str = "sam_vit_h_4b8939.pth"
processor: str = "seg_ofade20k"
processor_res: int = 512
pixel_perfect: bool = False
resize_mode: Optional[int] = 1
target_W: Optional[int] = None
target_H: Optional[int] = None
# Forwarded unchanged; the expected format is defined by categorical_mask.
category: str
# Handed to decode_to_pil (existing file path or base64 string).
input_image: str

@app.post("/sam/category-mask")
async def api_category_mask(payload: CategoryMaskRequest = Body(...),
                            autosam_conf: AutoSAMConfig = Body(...)) -> Any:
    # Runs categorical_mask on the decoded image. When exactly three images come
    # back they are returned as blended image / mask / masked image; the resized
    # input is added whenever categorical_mask provides one.
    print(f"SAM API /sam/category-mask received request")
    payload.input_image = decode_to_pil(payload.input_image)
    # Order matters: these are passed positionally after the payload fields.
    autosam_fields = (
        "points_per_side",
        "points_per_batch",
        "pred_iou_thresh",
        "stability_score_thresh",
        "stability_score_offset",
        "box_nms_thresh",
        "crop_n_layers",
        "crop_nms_thresh",
        "crop_overlap_ratio",
        "crop_n_points_downscale_factor",
        "min_mask_region_area",
    )
    autosam_args = [getattr(autosam_conf, field) for field in autosam_fields]
    mask_images, mask_message, resized_input_img = categorical_mask(
        payload.sam_model_name,
        payload.processor,
        payload.processor_res,
        payload.pixel_perfect,
        payload.resize_mode,
        payload.target_W,
        payload.target_H,
        payload.category,
        payload.input_image,
        *autosam_args)
    encoded = [encode_to_base64(item) for item in mask_images]
    print(f"SAM API /sam/category-mask finished with message {mask_message}")
    result = {"msg": mask_message}
    if len(encoded) == 3:
        result.update(zip(("blended_image", "mask", "masked_image"), encoded))
    if resized_input_img is not None:
        result["resized_input"] = encode_to_base64(resized_input_img)
    return result


try:
import modules.script_callbacks as script_callbacks
Expand Down
24 changes: 13 additions & 11 deletions scripts/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def show_masks(image_np, masks: np.ndarray, alpha=0.5):

def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):
print("Dilation Amount: ", dilation_amt)
mask_image = Image.open(mask_gallery[chosen_mask + 3]['name'])
if isinstance(mask_gallery, dict):
mask_image = Image.open(mask_gallery[chosen_mask + 3]['name'])
else:
mask_image = mask_gallery
binary_img = np.array(mask_image.convert('1'))
if dilation_amt:
mask_image, binary_img = dilate_mask(binary_img, dilation_amt)
Expand Down Expand Up @@ -135,18 +138,17 @@ def dilate_mask(mask, dilation_amt):
return dilated_mask, dilated_binary_img


# Builds the output gallery for a set of masks: for every mask a blended
# preview, a binary mask image and a matted copy of the input. Returns the three
# lists concatenated in that order (blended + masks + matted).
# NOTE(review): two signatures and two variants of the loop body appear back to
# back below (with and without `gui`) — presumably the before/after sides of a
# diff; confirm which one is live before relying on this text.
def create_mask_output(image_np, masks, boxes_filt, gui):
def create_mask_output(image_np, masks, boxes_filt):
print("Creating output image")
mask_images, masks_gallery, matted_images = [], [], []
# boxes_filt may be None; otherwise it is converted to an integer numpy array.
boxes_filt = boxes_filt.numpy().astype(int) if boxes_filt is not None else None
for mask in masks:
# Collapse the mask along axis 0 into a single boolean image.
masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
if gui:
blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
mask_images.append(Image.fromarray(blended_image))
image_np_copy = copy.deepcopy(image_np)
image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
matted_images.append(Image.fromarray(image_np_copy))
# Overlay of the boxes and the mask on the input image.
blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
mask_images.append(Image.fromarray(blended_image))
# Work on a copy so the caller's array is left untouched, then set every
# pixel outside the mask to [0, 0, 0, 0].
image_np_copy = copy.deepcopy(image_np)
image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
matted_images.append(Image.fromarray(image_np_copy))
return mask_images + masks_gallery + matted_images


Expand Down Expand Up @@ -179,7 +181,7 @@ def create_mask_batch_output(

def sam_predict(sam_model_name, input_image, positive_points, negative_points,
dino_checkbox, dino_model_name, text_prompt, box_threshold,
dino_preview_checkbox, dino_preview_boxes_selection, gui=True):
dino_preview_checkbox, dino_preview_boxes_selection):
print("Start SAM Processing")
if sam_model_name is None:
return [], "SAM model not found. Please download SAM model from extension README."
Expand Down Expand Up @@ -234,7 +236,7 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
multimask_output=True)
masks = masks[:, None, ...]
garbage_collect(sam)
return create_mask_output(image_np, masks, boxes_filt, gui), sam_predict_status + sam_predict_result
return create_mask_output(image_np, masks, boxes_filt), sam_predict_status + sam_predict_result


def dino_predict(input_image, dino_model_name, text_prompt, box_threshold):
Expand Down Expand Up @@ -364,7 +366,7 @@ def categorical_mask(
garbage_collect(sam)
if isinstance(outputs, str):
return [], outputs, None
output_gallery = create_mask_output(resized_input_image_np, outputs[None, None, ...], None, True)
output_gallery = create_mask_output(resized_input_image_np, outputs[None, None, ...], None)
return output_gallery, "Done", resized_input_image_pil


Expand Down

0 comments on commit 91d23c5

Please sign in to comment.