cleanup ✨
aniketmaurya committed Nov 30, 2023
1 parent a1107ee commit 73c359c
Showing 4 changed files with 27 additions and 5 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,2 +1,4 @@
 fastapi[all]
 llama-cpp-python
+transformers
+diffusers
8 changes: 7 additions & 1 deletion src/fastserve/models/__main__.py
@@ -1,15 +1,21 @@
 import argparse

+from fastserve.utils import get_default_device
+
 from .ssd import FastServeSSD

 parser = argparse.ArgumentParser(description="Serve models with FastServe")
 parser.add_argument("--model", type=str, required=True, help="Name of the model")
+parser.add_argument("--device", type=str, required=False, help="Device")
+

 args = parser.parse_args()

 app = None
+device = args.device or get_default_device()
+
 if args.model == "ssd-1b":
-    app = FastServeSSD(device="mps")
+    app = FastServeSSD(device=device)
 else:
     raise Exception(f"FastServe.models doesn't implement model={args.model}")

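With this change the entrypoint no longer hard-codes device="mps": the device comes from the --device flag when given, and otherwise from get_default_device(). Assuming the package is run as a module (the command below is inferred from the src/fastserve/models/__main__.py path, so treat it as an assumption), serving SSD-1B on an explicit device would look like:

python -m fastserve.models --model ssd-1b --device cuda

Omitting --device falls back to automatic detection: cuda, then mps, then cpu.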
13 changes: 9 additions & 4 deletions src/fastserve/models/ssd.py
@@ -15,8 +15,11 @@ class PromptRequest(BaseModel):


 class FastServeSSD(FastServe):
-    def __init__(self, batch_size=2, timeout=0.5, device="cuda") -> None:
+    def __init__(
+        self, batch_size=2, timeout=0.5, device="cuda", num_inference_steps: int = 1
+    ) -> None:
         super().__init__(batch_size, timeout)
+        self.num_inference_steps = num_inference_steps
         self.input_schema = PromptRequest
         self.pipe = StableDiffusionXLPipeline.from_pretrained(
             "segmind/SSD-1B",
@@ -31,14 +34,16 @@ def handle(self, batch: List[PromptRequest]) -> List[StreamingResponse]:
         negative_prompts = [b.negative_prompt for b in batch]

         pil_images = self.pipe(
-            prompt=prompts, negative_prompt=negative_prompts, num_inference_steps=1
+            prompt=prompts,
+            negative_prompt=negative_prompts,
+            num_inference_steps=self.num_inference_steps,
         ).images
         image_bytes_list = []
         for pil_image in pil_images:
             image_bytes = io.BytesIO()
             pil_image.save(image_bytes, format="JPEG")
-            image_bytes_list.append(image_bytes)
+            image_bytes_list.append(image_bytes.getvalue())
         return [
-            StreamingResponse(image_bytes, media_type="image/jpeg")
+            StreamingResponse(iter([image_bytes]), media_type="image/jpeg")
             for image_bytes in image_bytes_list
         ]
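Besides plumbing num_inference_steps through, the handle() change fixes a subtle bug: the old code appended BytesIO objects whose cursor sat at end-of-buffer after save(), and handed them straight to StreamingResponse, which iterates from the current position and so would produce an empty body. Appending image_bytes.getvalue() stores the raw JPEG bytes instead, and iter([image_bytes]) gives StreamingResponse the chunk iterator it expects (a bare bytes object would be iterated byte-by-byte). A minimal self-contained sketch of the corrected pattern; image_to_jpeg_response is a hypothetical helper, not part of this repo:

import io

from fastapi.responses import StreamingResponse
from PIL import Image


def image_to_jpeg_response(pil_image: Image.Image) -> StreamingResponse:
    # Hypothetical helper illustrating the pattern from handle() above.
    buffer = io.BytesIO()
    pil_image.save(buffer, format="JPEG")  # leaves the cursor at end-of-buffer
    jpeg_bytes = buffer.getvalue()         # copies all bytes regardless of cursor
    # StreamingResponse consumes an iterator of chunks; one bytes chunk suffices.
    return StreamingResponse(iter([jpeg_bytes]), media_type="image/jpeg")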
9 changes: 9 additions & 0 deletions src/fastserve/utils.py
@@ -0,0 +1,9 @@
+import torch
+
+
+def get_default_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    if torch.backends.mps.is_available():
+        return "mps"
+    return "cpu"
