From 55cf7bf858f52da4dd218e7aa5a8e6a8dd765e66 Mon Sep 17 00:00:00 2001
From: Antunes
Date: Mon, 18 Mar 2024 09:17:22 -0300
Subject: [PATCH] refactoring device parameter for HF support

---
 src/fastserve/models/__main__.py    | 2 +-
 src/fastserve/models/huggingface.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/fastserve/models/__main__.py b/src/fastserve/models/__main__.py
index eb56dc0..52e12c6 100644
--- a/src/fastserve/models/__main__.py
+++ b/src/fastserve/models/__main__.py
@@ -77,7 +77,7 @@
 app = ServeHuggingFace(
     model_name=args.model_name,
     use_gpu=True if args.use_gpu else False,
-    device=device,
+    device="cuda" if args.use_gpu else device,
     timeout=args.timeout,
     batch_size=args.batch_size,
 )
diff --git a/src/fastserve/models/huggingface.py b/src/fastserve/models/huggingface.py
index ddc64af..217e9df 100644
--- a/src/fastserve/models/huggingface.py
+++ b/src/fastserve/models/huggingface.py
@@ -20,9 +20,10 @@ class PromptRequest(BaseModel):
 
 
 class ServeHuggingFace(FastServe):
-    def __init__(self, model_name: str = None, use_gpu: bool = False, **kwargs):
+    def __init__(self, model_name: str = None, use_gpu: bool = False, device="cpu", **kwargs):
         # Determine execution mode from environment or explicit parameter
         self.use_gpu = use_gpu or os.getenv("USE_GPU", "false").lower() in ["true", "1"]
+        self.device = device
 
         # HF authentication
         hf_token = os.getenv("HUGGINGFACE_TOKEN")
@@ -74,7 +75,7 @@ def __call__(self, request: PromptRequest) -> Any:
         inputs = self.tokenizer.encode(request.prompt, return_tensors="pt")
 
         if self.use_gpu:
-            inputs = inputs.to("cuda")
+            inputs = inputs.to(self.device)
 
         output = self.model.generate(
             inputs,