Skip to content

Commit

Permalink
Refactor device parameter for Hugging Face (HF) support
Browse files Browse the repository at this point in the history
  • Loading branch information
antunsz committed Mar 18, 2024
1 parent 5c8ad25 commit 55cf7bf
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/fastserve/models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
app = ServeHuggingFace(
model_name=args.model_name,
use_gpu=True if args.use_gpu else False,
device=device,
device="cuda" if args.use_gpu else device,
timeout=args.timeout,
batch_size=args.batch_size,
)
Expand Down
5 changes: 3 additions & 2 deletions src/fastserve/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ class PromptRequest(BaseModel):


class ServeHuggingFace(FastServe):
def __init__(self, model_name: str = None, use_gpu: bool = False, **kwargs):
def __init__(self, model_name: str = None, use_gpu: bool = False, device="cpu", **kwargs):
# Determine execution mode from environment or explicit parameter
self.use_gpu = use_gpu or os.getenv("USE_GPU", "false").lower() in ["true", "1"]
self.device = device

# HF authentication
hf_token = os.getenv("HUGGINGFACE_TOKEN")
Expand Down Expand Up @@ -74,7 +75,7 @@ def __call__(self, request: PromptRequest) -> Any:
inputs = self.tokenizer.encode(request.prompt, return_tensors="pt")

if self.use_gpu:
inputs = inputs.to("cuda")
inputs = inputs.to(self.device)

output = self.model.generate(
inputs,
Expand Down

0 comments on commit 55cf7bf

Please sign in to comment.