diff --git a/requirements.txt b/requirements.txt
index b39b56c..a763790 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,4 @@
 fastapi[all]
 llama-cpp-python
+transformers
+diffusers
diff --git a/src/fastserve/models/__main__.py b/src/fastserve/models/__main__.py
index 207e50b..6f44aa1 100644
--- a/src/fastserve/models/__main__.py
+++ b/src/fastserve/models/__main__.py
@@ -1,15 +1,21 @@
 import argparse
 
+from fastserve.utils import get_default_device
+
 from .ssd import FastServeSSD
 
 parser = argparse.ArgumentParser(description="Serve models with FastServe")
 parser.add_argument("--model", type=str, required=True, help="Name of the model")
+parser.add_argument("--device", type=str, required=False, help="Device")
+
 args = parser.parse_args()
 
 app = None
+device = args.device or get_default_device()
+
 if args.model == "ssd-1b":
-    app = FastServeSSD(device="mps")
+    app = FastServeSSD(device=device)
 else:
     raise Exception(f"FastServe.models doesn't implement model={args.model}")
diff --git a/src/fastserve/models/ssd.py b/src/fastserve/models/ssd.py
index e5d1b33..16803de 100644
--- a/src/fastserve/models/ssd.py
+++ b/src/fastserve/models/ssd.py
@@ -15,8 +15,11 @@ class PromptRequest(BaseModel):
 
 
 class FastServeSSD(FastServe):
-    def __init__(self, batch_size=2, timeout=0.5, device="cuda") -> None:
+    def __init__(
+        self, batch_size=2, timeout=0.5, device="cuda", num_inference_steps: int = 1
+    ) -> None:
         super().__init__(batch_size, timeout)
+        self.num_inference_steps = num_inference_steps
         self.input_schema = PromptRequest
         self.pipe = StableDiffusionXLPipeline.from_pretrained(
             "segmind/SSD-1B",
@@ -31,14 +34,16 @@ def handle(self, batch: List[PromptRequest]) -> List[StreamingResponse]:
         negative_prompts = [b.negative_prompt for b in batch]
 
         pil_images = self.pipe(
-            prompt=prompts, negative_prompt=negative_prompts, num_inference_steps=1
+            prompt=prompts,
+            negative_prompt=negative_prompts,
+            num_inference_steps=self.num_inference_steps,
         ).images
         image_bytes_list = []
         for pil_image in pil_images:
             image_bytes = io.BytesIO()
             pil_image.save(image_bytes, format="JPEG")
-            image_bytes_list.append(image_bytes)
+            image_bytes_list.append(image_bytes.getvalue())
 
         return [
-            StreamingResponse(image_bytes, media_type="image/jpeg")
+            StreamingResponse(iter([image_bytes]), media_type="image/jpeg")
             for image_bytes in image_bytes_list
         ]
diff --git a/src/fastserve/utils.py b/src/fastserve/utils.py
new file mode 100644
index 0000000..bc6e194
--- /dev/null
+++ b/src/fastserve/utils.py
@@ -0,0 +1,9 @@
+import torch
+
+
+def get_default_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    if torch.backends.mps.is_available():
+        return "mps"
+    return "cpu"
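
Usage sketch (not part of the diff above): with these changes the entry point no longer hard-codes "mps" and can be launched as "python -m fastserve.models --model ssd-1b", with get_default_device() picking cuda, then mps, then cpu, or pinned explicitly via "--device cuda". The client snippet below is an assumption: the route path ("/endpoint"), host, and port come from the FastServe base class and server runner, neither of which appears in this diff.

    # Hypothetical client for the SSD-1B server; the URL is an assumption.
    import requests

    # PromptRequest (see ssd.py) carries a prompt and a negative_prompt.
    response = requests.post(
        "http://localhost:8000/endpoint",
        json={"prompt": "an astronaut riding a horse", "negative_prompt": "blurry"},
    )

    # handle() returns a StreamingResponse with media_type="image/jpeg",
    # so the response body is raw JPEG bytes.
    with open("ssd_1b_sample.jpg", "wb") as f:
        f.write(response.content)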