diff --git a/README.md b/README.md
index be62b91..7fcc898 100644
--- a/README.md
+++ b/README.md
@@ -15,9 +15,18 @@ This extension aim for connecting [AUTOMATIC1111 Stable Diffusion WebUI](https:/
 - Image layout generation (single image + batch process)
 - *Image masking with categories (single image + batch process)
 - *Inpaint not masked for ControlNet inpainting on txt2img panel
+- `2023/04/29`: [Feature] The API has been completely refactored. You can now access all **single image process** features through the API. API documentation has moved to the [wiki](https://github.com/continue-revolution/sd-webui-segment-anything/wiki/API). This extension was significantly refactored on `2023/04/24`; if you wish to revert to an older version, run `git checkout 724b4db`.
+
+## TODO
+
+- [ ] Color selection for mask region and unmask region
+- [ ] Batch ControlNet inpainting
+- [ ] Only upload mask (Add content to image)
+- [ ] "Masked content"
+- [ ] Test EditAnything
 
 ## FAQ
 
 Thanks for suggestions from [github issues](https://github.com/continue-revolution/sd-webui-segment-anything/issues), [reddit](https://www.reddit.com/r/StableDiffusion/comments/12hkdy8/sd_webui_segment_everything/) and [bilibili](https://www.bilibili.com/video/BV1Tg4y1u73r/) to make this extension better.
@@ -170,55 +179,6 @@ Mask by Category batch demo
 | --- | --- | --- | --- |
 | ![1NHa6Wc](https://user-images.githubusercontent.com/63914308/234085498-70ca1d4c-cc5a-44d4-adb2-366630e5ce24.png) | ![1NHa6Wc_0_output](https://user-images.githubusercontent.com/63914308/234085495-0bfc4114-3e81-4ace-81d6-0f0f3186df25.png) | ![1NHa6Wc_0_mask](https://user-images.githubusercontent.com/63914308/234085491-8976f46c-2617-47ee-968e-0a9dd479c63a.png) | ![1NHa6Wc_0_blend](https://user-images.githubusercontent.com/63914308/234085503-7e041373-39cd-4f20-8696-986be517f188.png)
-## API Support
-
-### API Usage
-
-We have added an API endpoint to allow for automated workflows.
-
-The API utilizes both Segment Anything and GroundingDINO to return masks of all instances of whatever object is specified in the text prompt.
-
-This is an extension of the existing [Stable Diffusion Web UI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API).
-
-There are 2 endpoints exposed
-- GET `/sam-webui/heartbeat`
-- POST `/sam-webui/image-mask`
-
-The heartbeat endpoint can be used to ensure that the API is up.
-
-The image-mask endpoint accepts a payload that includes your base64-encoded image.
-
-Below is an example of how to interface with the API using requests.
-
-### API Example
-
-```
-import base64
-import requests
-from PIL import Image
-from io import BytesIO
-
-url = "http://127.0.0.1:7860/sam-webui/image-mask"
-
-def image_to_base64(img_path: str) -> str:
-    with open(img_path, "rb") as img_file:
-        img_base64 = base64.b64encode(img_file.read()).decode()
-    return img_base64
-
-payload = {
-    "image": image_to_base64("IMAGE_FILE_PATH"),
-    "prompt": "TEXT PROMPT",
-    "box_threshold": 0.3,
-    "padding": 30 #Optional param to pad masks
-}
-res = requests.post(url, json=payload)
-
-for dct in res.json():
-    image_data = base64.b64decode(dct['image'])
-    image = Image.open(BytesIO(image_data))
-    image.show()
-```
-
 ## Contribute
 
 Disclaimer: I have not thoroughly tested this extension, so there might be bugs.
 Bear with me while I'm fixing them :)
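Note for anyone migrating scripts: the README example removed above targeted the old `/sam-webui/image-mask` route, which no longer exists. A minimal client for the refactored single-image endpoint (added to `scripts/api.py` below) could look like the following sketch; the server URL, image path, and prompt are placeholder assumptions, while the field and response names follow the new `SamPredictRequest` model:

```
import base64
import requests

# Placeholder inputs: adjust the URL, file name, and prompt to your setup.
url = "http://127.0.0.1:7860/sam/sam-predict"
with open("dog.png", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode()

payload = {
    "sam_model_name": "sam_vit_h_4b8939.pth",
    "input_image": img_b64,
    "dino_enabled": True,        # use GroundingDINO instead of point prompts
    "dino_text_prompt": "dog",
    "dino_box_threshold": 0.3,
}
reply = requests.post(url, json=payload).json()
print(reply["msg"])

# On success the reply carries three lists of three base64 images each:
# "blended_images", "masks" and "masked_images".
with open("mask_0.png", "wb") as f:
    f.write(base64.b64decode(reply["masks"][0]))
```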
diff --git a/scripts/api.py b/scripts/api.py
index aebead2..fd52ff9 100644
--- a/scripts/api.py
+++ b/scripts/api.py
@@ -1,61 +1,238 @@
+import os
 from fastapi import FastAPI, Body
-from io import BytesIO
-import base64
 from pydantic import BaseModel
-from typing import Any, Optional
-import asyncio
+from typing import Any, Optional, List
 import gradio as gr
-import os
-from scripts.sam import init_sam_model, dilate_mask, sam_predict, sam_model_list
-from scripts.dino import dino_model_list
-from PIL import Image, ImageChops
-import base64
+from PIL import Image
+import numpy as np
+
+from modules.api.api import encode_pil_to_base64, decode_base64_to_image
+from scripts.sam import sam_predict, dino_predict, update_mask, cnet_seg, categorical_mask
+
+
+def decode_to_pil(image):
+    if isinstance(image, str):
+        # A string may be a file path or a base64-encoded image.
+        # Check the str type first: os.path.exists raises TypeError on non-path input.
+        if os.path.exists(image):
+            return Image.open(image)
+        return decode_base64_to_image(image)
+    elif isinstance(image, Image.Image):
+        return image
+    elif isinstance(image, np.ndarray):
+        return Image.fromarray(image)
+    else:
+        raise Exception("Not an image")
+
+
+def encode_to_base64(image):
+    if isinstance(image, str):
+        return image
+    elif isinstance(image, Image.Image):
+        return encode_pil_to_base64(image).decode()
+    elif isinstance(image, np.ndarray):
+        pil = Image.fromarray(image)
+        return encode_pil_to_base64(pil).decode()
+    else:
+        raise Exception("Invalid type")
+
 
 def sam_api(_: gr.Blocks, app: FastAPI):
-    @app.get("/sam-webui/heartbeat")
+    @app.get("/sam/heartbeat")
     async def heartbeat():
         return {
             "msg": "Success!"
         }
+
+    class SamPredictRequest(BaseModel):
+        sam_model_name: str = "sam_vit_h_4b8939.pth"
+        input_image: str
+        sam_positive_points: List[List[float]] = []
+        sam_negative_points: List[List[float]] = []
+        dino_enabled: bool = False
+        dino_model_name: Optional[str] = "GroundingDINO_SwinT_OGC (694MB)"
+        dino_text_prompt: Optional[str] = None
+        dino_box_threshold: Optional[float] = 0.3
+        dino_preview_checkbox: bool = False
+        dino_preview_boxes_selection: Optional[List[int]] = None
+
+    @app.post("/sam/sam-predict")
+    async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any:
+        print(f"SAM API /sam/sam-predict received request")
+        payload.input_image = decode_to_pil(payload.input_image).convert('RGBA')
+        sam_output_mask_gallery, sam_message = sam_predict(
+            payload.sam_model_name,
+            payload.input_image,
+            payload.sam_positive_points,
+            payload.sam_negative_points,
+            payload.dino_enabled,
+            payload.dino_model_name,
+            payload.dino_text_prompt,
+            payload.dino_box_threshold,
+            payload.dino_preview_checkbox,
+            payload.dino_preview_boxes_selection)
+        print(f"SAM API /sam/sam-predict finished with message: {sam_message}")
+        result = {
+            "msg": sam_message,
+        }
+        if len(sam_output_mask_gallery) == 9:
+            result["blended_images"] = list(map(encode_to_base64, sam_output_mask_gallery[:3]))
+            result["masks"] = list(map(encode_to_base64, sam_output_mask_gallery[3:6]))
+            result["masked_images"] = list(map(encode_to_base64, sam_output_mask_gallery[6:]))
+        return result
+
+    class DINOPredictRequest(BaseModel):
+        input_image: str
+        dino_model_name: str = "GroundingDINO_SwinT_OGC (694MB)"
+        text_prompt: str
+        box_threshold: float = 0.3
+
+    @app.post("/sam/dino-predict")
+    async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any:
+        print(f"SAM API /sam/dino-predict received request")
+        payload.input_image = decode_to_pil(payload.input_image)
+        dino_output_img, _, dino_msg = dino_predict(
+            payload.input_image,
+            payload.dino_model_name,
+            payload.text_prompt,
+            payload.box_threshold)
+        if "value" in dino_msg:
+            dino_msg = dino_msg["value"]
+        else:
+            dino_msg = "Done"
+        print(f"SAM API /sam/dino-predict finished with message: {dino_msg}")
+        return {
+            "msg": dino_msg,
+            "image_with_box": encode_to_base64(dino_output_img) if dino_output_img is not None else None,
+        }
+
+    class DilateMaskRequest(BaseModel):
+        input_image: str
+        mask: str
+        dilate_amount: int = 10
+
+    @app.post("/sam/dilate-mask")
+    async def api_dilate_mask(payload: DilateMaskRequest = Body(...)) -> Any:
+        print(f"SAM API /sam/dilate-mask received request")
+        payload.input_image = decode_to_pil(payload.input_image).convert("RGBA")
+        payload.mask = decode_to_pil(payload.mask)
+        dilate_result = list(map(encode_to_base64, update_mask(payload.mask, 0, payload.dilate_amount, payload.input_image)))
+        print(f"SAM API /sam/dilate-mask finished")
+        return {"blended_image": dilate_result[0], "mask": dilate_result[1], "masked_image": dilate_result[2]}
+
+
+    class AutoSAMConfig(BaseModel):
+        points_per_side: Optional[int] = 32
+        points_per_batch: int = 64
+        pred_iou_thresh: float = 0.88
+        stability_score_thresh: float = 0.95
+        stability_score_offset: float = 1.0
+        box_nms_thresh: float = 0.7
+        crop_n_layers: int = 0
+        crop_nms_thresh: float = 0.7
+        crop_overlap_ratio: float = 512 / 1500
+        crop_n_points_downscale_factor: int = 1
+        min_mask_region_area: int = 0
+
+    class ControlNetSegRequest(BaseModel):
+        sam_model_name: str = "sam_vit_h_4b8939.pth"
+        input_image: str
+        processor: str = "seg_ofade20k"
+        processor_res: int = 512
+        pixel_perfect: bool = False
+        resize_mode: Optional[int] = 1  # 0: just resize, 1: crop and resize, 2: resize and fill
+        target_W: Optional[int] = None
+        target_H: Optional[int] = None
+
+    @app.post("/sam/controlnet-seg")
+    async def api_controlnet_seg(payload: ControlNetSegRequest = Body(...),
+                                 autosam_conf: AutoSAMConfig = Body(...)) -> Any:
+        print(f"SAM API /sam/controlnet-seg received request")
+        payload.input_image = decode_to_pil(payload.input_image)
+        cnet_seg_img, cnet_seg_msg = cnet_seg(
+            payload.sam_model_name,
+            payload.input_image,
+            payload.processor,
+            payload.processor_res,
+            payload.pixel_perfect,
+            payload.resize_mode,
+            payload.target_W,
+            payload.target_H,
+            autosam_conf.points_per_side,
+            autosam_conf.points_per_batch,
+            autosam_conf.pred_iou_thresh,
+            autosam_conf.stability_score_thresh,
+            autosam_conf.stability_score_offset,
+            autosam_conf.box_nms_thresh,
+            autosam_conf.crop_n_layers,
+            autosam_conf.crop_nms_thresh,
+            autosam_conf.crop_overlap_ratio,
+            autosam_conf.crop_n_points_downscale_factor,
+            autosam_conf.min_mask_region_area)
+        cnet_seg_img = list(map(encode_to_base64, cnet_seg_img))
+        print(f"SAM API /sam/controlnet-seg finished with message {cnet_seg_msg}")
+        result = {
+            "msg": cnet_seg_msg,
+        }
+        if len(cnet_seg_img) == 3:
+            result["blended_images"] = cnet_seg_img[0]
+            result["random_seg"] = cnet_seg_img[1]
+            result["edit_anything_control"] = cnet_seg_img[2]
+        elif len(cnet_seg_img) == 4:
+            result["sem_presam"] = cnet_seg_img[0]
+            result["sem_postsam"] = cnet_seg_img[1]
+            result["blended_presam"] = cnet_seg_img[2]
+            result["blended_postsam"] = cnet_seg_img[3]
+        return result
 
-    class MaskRequest(BaseModel):
-        image: str #base64 string containing image
-        prompt: str
-        box_threshold: float
-        padding: Optional[int] = 0
-
-
-    def pil_image_to_base64(img: Image.Image) -> str:
-        buffered = BytesIO()
-        img.save(buffered, format="JPEG")
-        img_base64 = base64.b64encode(buffered.getvalue()).decode()
-        return img_base64
-
@app.post("/sam-webui/image-mask") - async def process_image(payload: MaskRequest = Body(...)) -> Any: - sam_model_name = sam_model_list[0] if len(sam_model_list) > 0 else None - dino_model_name = dino_model_list[0] if len(dino_model_list) > 0 else None - # Decode the base64 image string - img_b64 = base64.b64decode(payload.image) - input_img = Image.open(BytesIO(img_b64)) - #Run DINO and SAM inference to get masks back - masks = sam_predict(sam_model_name, - input_img, - [], - [], - True, - dino_model_name, - payload.prompt, - payload.box_threshold, - None, - None, - gui=False)[0] - if payload.padding: - masks = [dilate_mask(mask, payload.padding)[0] for mask in masks] - # Convert the final PIL image to a base64 string - response = [{"image": pil_image_to_base64(mask)} for mask in masks] - - return response + class CategoryMaskRequest(BaseModel): + sam_model_name: str = "sam_vit_h_4b8939.pth" + processor: str = "seg_ofade20k" + processor_res: int = 512 + pixel_perfect: bool = False + resize_mode: Optional[int] = 1 + target_W: Optional[int] = None + target_H: Optional[int] = None + category: str + input_image: str + + @app.post("/sam/category-mask") + async def api_category_mask(payload: CategoryMaskRequest = Body(...), + autosam_conf: AutoSAMConfig = Body(...)) -> Any: + print(f"SAM API /sam/category-mask received request") + payload.input_image = decode_to_pil(payload.input_image) + category_mask_img, category_mask_msg, resized_input_img = categorical_mask( + payload.sam_model_name, + payload.processor, + payload.processor_res, + payload.pixel_perfect, + payload.resize_mode, + payload.target_W, + payload.target_H, + payload.category, + payload.input_image, + autosam_conf.points_per_side, + autosam_conf.points_per_batch, + autosam_conf.pred_iou_thresh, + autosam_conf.stability_score_thresh, + autosam_conf.stability_score_offset, + autosam_conf.box_nms_thresh, + autosam_conf.crop_n_layers, + autosam_conf.crop_nms_thresh, + autosam_conf.crop_overlap_ratio, + autosam_conf.crop_n_points_downscale_factor, + autosam_conf.min_mask_region_area) + category_mask_img = list(map(encode_to_base64, category_mask_img)) + print(f"SAM API /sam/category-mask finished with message {category_mask_msg}") + result = { + "msg": category_mask_msg, + } + if len(category_mask_img) == 3: + result["blended_image"] = category_mask_img[0] + result["mask"] = category_mask_img[1] + result["masked_image"] = category_mask_img[2] + if resized_input_img is not None: + result["resized_input"] = encode_to_base64(resized_input_img) + return result + try: import modules.script_callbacks as script_callbacks diff --git a/scripts/process_params.py b/scripts/process_params.py index f73b396..cae9254 100644 --- a/scripts/process_params.py +++ b/scripts/process_params.py @@ -22,9 +22,6 @@ def __init__(self, args: Tuple, is_img2img=False): self.output_chosen_mask: int = 0 self.dilation_checkbox: bool = False self.dilation_output_gallery: List[Dict] = None - self.sketch_checkbox: bool = False - self.inpaint_color_sketch = None - self.inpaint_mask_alpha: int = 0 self.init_sam_single_image_process(args) @@ -37,9 +34,6 @@ def init_sam_single_image_process(self, args): self.output_chosen_mask = args[5] self.dilation_checkbox = args[6] self.dilation_output_gallery = args[7] - self.sketch_checkbox = args[8] - self.inpaint_color_sketch = args[9] - self.inpaint_mask_alpha = args[10] def get_input_and_mask(self, mask_blur): @@ -51,14 +45,14 @@ def get_input_and_mask(self, mask_blur): mask = 
diff --git a/scripts/process_params.py b/scripts/process_params.py
index f73b396..cae9254 100644
--- a/scripts/process_params.py
+++ b/scripts/process_params.py
@@ -22,9 +22,6 @@ def __init__(self, args: Tuple, is_img2img=False):
         self.output_chosen_mask: int = 0
         self.dilation_checkbox: bool = False
         self.dilation_output_gallery: List[Dict] = None
-        self.sketch_checkbox: bool = False
-        self.inpaint_color_sketch = None
-        self.inpaint_mask_alpha: int = 0
 
         self.init_sam_single_image_process(args)
 
@@ -37,9 +34,6 @@ def init_sam_single_image_process(self, args):
         self.output_chosen_mask = args[5]
         self.dilation_checkbox = args[6]
         self.dilation_output_gallery = args[7]
-        self.sketch_checkbox = args[8]
-        self.inpaint_color_sketch = args[9]
-        self.inpaint_mask_alpha = args[10]
 
 
     def get_input_and_mask(self, mask_blur):
@@ -51,14 +45,14 @@ def get_input_and_mask(self, mask_blur):
             mask = Image.open(self.output_mask_gallery[self.output_chosen_mask + 3]['name']).convert('L')
         if mask is not None and self.cnet_inpaint_invert:
             mask = ImageOps.invert(mask)
-        if self.is_img2img and self.sketch_checkbox and self.inpaint_color_sketch is not None and mask is not None:
-            alpha = np.expand_dims(np.array(mask) / 255, axis=-1)
-            image = np.uint8(np.array(self.inpaint_color_sketch) * alpha + np.array(self.input_image) * (1 - alpha))
-            mask = ImageEnhance.Brightness(mask).enhance(1 - self.inpaint_mask_alpha / 100)
-            blur = ImageFilter.GaussianBlur(mask_blur)
-            image = Image.composite(image.filter(blur), self.input_image, mask.filter(blur)).convert("RGB")
-        else:
-            image = self.input_image
+        # if self.is_img2img and self.sketch_checkbox and self.inpaint_color_sketch is not None and mask is not None:
+        #     alpha = np.expand_dims(np.array(mask) / 255, axis=-1)
+        #     image = np.uint8(np.array(self.inpaint_color_sketch) * alpha + np.array(self.input_image) * (1 - alpha))
+        #     mask = ImageEnhance.Brightness(mask).enhance(1 - self.inpaint_mask_alpha / 100)
+        #     blur = ImageFilter.GaussianBlur(mask_blur)
+        #     image = Image.composite(image.filter(blur), self.input_image, mask.filter(blur)).convert("RGB")
+        # else:
+        image = self.input_image
         return image, mask
 
 
@@ -66,8 +60,8 @@ def get_input_and_mask(self, mask_blur):
 class SAMProcessUnit:
     def __init__(self, args: Tuple, is_img2img=False):
         self.is_img2img = is_img2img
-        sam_inpaint_args = args[:11]
-        args = args[11:]
+        sam_inpaint_args = args[:8]
+        args = args[8:]
         self.sam_inpaint_unit = SAMInpaintUnit(sam_inpaint_args, is_img2img)
 
         self.cnet_seg_output_gallery: List[Dict] = None
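Since `SAMProcessUnit` now slices off the first 8 entries for the inpaint unit, any code that assembles this tuple by hand must drop the three sketch arguments. The expected order mirrors `init_sam_single_image_process` and the `sam_single_image_process` tuple in `ui()`; the variable names below are illustrative placeholders:

```
# Single-image args consumed by SAMInpaintUnit after this change: 8, not 11.
sam_inpaint_args = (
    inpaint_upload_enable,    # args[0]: bool, enable inpaint with mask upload
    cnet_inpaint_invert,      # args[1]: bool, invert the mask
    cnet_inpaint_idx,         # args[2]: int, ControlNet unit index
    input_image,              # args[3]: input image
    output_mask_gallery,      # args[4]: gallery list from the SAM tab
    output_chosen_mask,       # args[5]: int, which of the 3 masks to use
    dilation_checkbox,        # args[6]: bool, use the dilated mask
    dilation_output_gallery,  # args[7]: gallery list after dilation
)
```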
diff --git a/scripts/sam.py b/scripts/sam.py
index 7888b1e..b6aa1d2 100644
--- a/scripts/sam.py
+++ b/scripts/sam.py
@@ -56,7 +56,10 @@ def show_masks(image_np, masks: np.ndarray, alpha=0.5):
 
 def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):
     print("Dilation Amount: ", dilation_amt)
-    mask_image = Image.open(mask_gallery[chosen_mask + 3]['name'])
+    if isinstance(mask_gallery, list):
+        mask_image = Image.open(mask_gallery[chosen_mask + 3]['name'])
+    else:
+        mask_image = mask_gallery
     binary_img = np.array(mask_image.convert('1'))
     if dilation_amt:
         mask_image, binary_img = dilate_mask(binary_img, dilation_amt)
@@ -135,18 +138,17 @@ def dilate_mask(mask, dilation_amt):
     return dilated_mask, dilated_binary_img
 
 
-def create_mask_output(image_np, masks, boxes_filt, gui):
+def create_mask_output(image_np, masks, boxes_filt):
     print("Creating output image")
     mask_images, masks_gallery, matted_images = [], [], []
     boxes_filt = boxes_filt.numpy().astype(int) if boxes_filt is not None else None
     for mask in masks:
         masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
-        if gui:
-            blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
-            mask_images.append(Image.fromarray(blended_image))
-            image_np_copy = copy.deepcopy(image_np)
-            image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
-            matted_images.append(Image.fromarray(image_np_copy))
+        blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
+        mask_images.append(Image.fromarray(blended_image))
+        image_np_copy = copy.deepcopy(image_np)
+        image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
+        matted_images.append(Image.fromarray(image_np_copy))
     return mask_images + masks_gallery + matted_images
 
 
@@ -179,7 +181,7 @@ def create_mask_batch_output(
 
 def sam_predict(sam_model_name, input_image, positive_points, negative_points,
                 dino_checkbox, dino_model_name, text_prompt, box_threshold,
-                dino_preview_checkbox, dino_preview_boxes_selection, gui=True):
+                dino_preview_checkbox, dino_preview_boxes_selection):
     print("Start SAM Processing")
     if sam_model_name is None:
         return [], "SAM model not found. Please download SAM model from extension README."
@@ -234,7 +236,7 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
             multimask_output=True)
         masks = masks[:, None, ...]
     garbage_collect(sam)
-    return create_mask_output(image_np, masks, boxes_filt, gui), sam_predict_status + sam_predict_result
+    return create_mask_output(image_np, masks, boxes_filt), sam_predict_status + sam_predict_result
 
 
 def dino_predict(input_image, dino_model_name, text_prompt, box_threshold):
@@ -364,7 +366,7 @@ def categorical_mask(
     garbage_collect(sam)
     if isinstance(outputs, str):
         return [], outputs, None
-    output_gallery = create_mask_output(resized_input_image_np, outputs[None, None, ...], None, True)
+    output_gallery = create_mask_output(resized_input_image_np, outputs[None, None, ...], None)
     return output_gallery, "Done", resized_input_image_pil
 
 
@@ -608,8 +610,6 @@ def ui(self, is_img2img):
                             sam_inpaint_upload_enable, sam_cnet_inpaint_invert, sam_cnet_inpaint_idx,
                             sam_input_image, sam_output_mask_gallery, sam_output_chosen_mask,
                             sam_dilation_checkbox, sam_dilation_output_gallery)
-                        sam_sketch_checkbox, sam_inpaint_color_sketch, sam_inpaint_mask_alpha = ui_sketch(sam_input_image, is_img2img)
-                        sam_single_image_process += (sam_sketch_checkbox, sam_inpaint_color_sketch, sam_inpaint_mask_alpha)
                         ui_process += sam_single_image_process
 
                     with gr.TabItem(label="Batch Process"):
@@ -722,8 +722,6 @@ def layout_show(mode):
                             crop_inpaint_enable, crop_cnet_inpaint_invert, crop_cnet_inpaint_idx,
                             crop_resized_image, crop_output_gallery, crop_padding,
                             crop_dilation_checkbox, crop_dilation_output_gallery)
-                        crop_sketch_checkbox, crop_inpaint_color_sketch, crop_inpaint_mask_alpha = ui_sketch(crop_resized_image, is_img2img)
-                        crop_single_image_process += (crop_sketch_checkbox, crop_inpaint_color_sketch, crop_inpaint_mask_alpha)
                        ui_process += crop_single_image_process
 
                     with gr.TabItem(label="Batch Process"):
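Finally, a quick smoke test for the renamed routes (every old `/sam-webui/*` path is gone). The base URL and image file are assumptions; the endpoint names and response keys come from the new `scripts/api.py`:

```
import base64
import requests

BASE = "http://127.0.0.1:7860"  # assumed local WebUI

# 1. Heartbeat: the handler answers {"msg": "Success!"}.
assert requests.get(f"{BASE}/sam/heartbeat").json()["msg"] == "Success!"

# 2. GroundingDINO-only box preview via /sam/dino-predict.
with open("dog.png", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode()
reply = requests.post(f"{BASE}/sam/dino-predict", json={
    "input_image": img_b64,
    "text_prompt": "dog",
    "box_threshold": 0.3,
}).json()
print(reply["msg"])
if reply["image_with_box"] is not None:
    with open("boxes.png", "wb") as f:
        f.write(base64.b64decode(reply["image_with_box"]))
```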