From 6892888acbff0ef980341d96ea48a38fa3d4ea7c Mon Sep 17 00:00:00 2001
From: sunyi001 <1659275352@qq.com>
Date: Fri, 29 Nov 2024 10:51:08 +0800
Subject: [PATCH] Specify ASCEND NPU for inference.

---
 fastchat/model/model_adapter.py      |  7 +++++++
 fastchat/serve/cli.py                | 10 ++++++++++
 fastchat/serve/model_worker.py       | 10 ++++++++++
 fastchat/serve/multi_model_worker.py | 10 ++++++++++
 4 files changed, 37 insertions(+)

diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 9625df6db..5546672d7 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -512,6 +512,13 @@ def add_model_args(parser):
         help="A single GPU like 1 or multiple GPUs like 0,2",
     )
     parser.add_argument("--num-gpus", type=int, default=1)
+    parser.add_argument(
+        "--npus",
+        type=str,
+        default=None,
+        help="A single NPU like 1 or multiple NPUs like 0,2",
+    )
+    parser.add_argument("--num-npus", type=int, default=1)
     parser.add_argument(
         "--max-gpu-memory",
         type=str,
diff --git a/fastchat/serve/cli.py b/fastchat/serve/cli.py
index 78f7f51b1..7afc8d320 100644
--- a/fastchat/serve/cli.py
+++ b/fastchat/serve/cli.py
@@ -13,6 +13,7 @@
 - Type "!!save <filename>" to save the conversation history to a json file.
 - Type "!!load <filename>" to load a conversation history from a json file.
 """
+
 import argparse
 import os
 import re
@@ -197,6 +198,15 @@ def main(args):
         )
         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
         os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
+    if args.npus:
+        if len(args.npus.split(",")) < args.num_npus:
+            raise ValueError(
+                f"Larger --num-npus ({args.num_npus}) than --npus {args.npus}!"
+            )
+        if len(args.npus.split(",")) == 1:
+            import torch_npu
+
+            torch.npu.set_device(int(args.npus))
     if args.enable_exllama:
         exllama_config = ExllamaConfig(
             max_seq_len=args.exllama_max_seq_len,
diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py
index 683a78556..6ffdf3685 100644
--- a/fastchat/serve/model_worker.py
+++ b/fastchat/serve/model_worker.py
@@ -1,6 +1,7 @@
 """
 A model worker that executes the model.
 """
+
 import argparse
 import base64
 import gc
@@ -351,6 +352,15 @@ def create_model_worker():
             f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
         )
         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+    if args.npus:
+        if len(args.npus.split(",")) < args.num_npus:
+            raise ValueError(
+                f"Larger --num-npus ({args.num_npus}) than --npus {args.npus}!"
+            )
+        if len(args.npus.split(",")) == 1:
+            import torch_npu
+
+            torch.npu.set_device(int(args.npus))

     gptq_config = GptqConfig(
         ckpt=args.gptq_ckpt or args.model_path,
diff --git a/fastchat/serve/multi_model_worker.py b/fastchat/serve/multi_model_worker.py
index 5e6266fe0..24cd1c761 100644
--- a/fastchat/serve/multi_model_worker.py
+++ b/fastchat/serve/multi_model_worker.py
@@ -11,6 +11,7 @@
 We recommend using this with multiple Peft models (with `peft` in the name)
 where all Peft models are trained on the exact same base model.
 """
+
 import argparse
 import asyncio
 import dataclasses
@@ -206,6 +207,15 @@ def create_multi_model_worker():
             f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
        )
         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+    if args.npus:
+        if len(args.npus.split(",")) < args.num_npus:
+            raise ValueError(
+                f"Larger --num-npus ({args.num_npus}) than --npus {args.npus}!"
+            )
+        if len(args.npus.split(",")) == 1:
+            import torch_npu
+
+            torch.npu.set_device(int(args.npus))

     gptq_config = GptqConfig(
         ckpt=args.gptq_ckpt or args.model_path,
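
Note on the added branch (not part of the patch itself): it relies on Huawei's `torch_npu` adapter, which registers the "npu" device type with PyTorch when imported; after that, `torch.npu.set_device(...)` pins the process to one Ascend card. The sketch below is only an illustration of that flow under the assumption that `torch_npu` is installed, with hypothetical flag values (`--npus 0`, `--num-npus 1`):

# Illustrative only; assumes the Ascend "torch_npu" package is installed.
import torch
import torch_npu  # importing registers the "npu" device type with PyTorch

npus = "0"      # hypothetical value of --npus
num_npus = 1    # hypothetical value of --num-npus

# Same validation the patch adds: --num-npus must not exceed the NPUs listed.
if len(npus.split(",")) < num_npus:
    raise ValueError(f"Larger --num-npus ({num_npus}) than --npus {npus}!")

# Single-NPU case: pin this process to the chosen card.
if len(npus.split(",")) == 1:
    torch.npu.set_device(int(npus))

print(torch.npu.is_available())   # True once an Ascend NPU is visible
x = torch.ones(2, 2).to("npu")    # tensors can now be placed on the NPU

With the patch applied, the single-card path would be exercised by an invocation along the lines of `python3 -m fastchat.serve.cli --model-path <model> --device npu --npus 0`, assuming the existing `--device npu` choice in `add_model_args`; the command is shown for illustration only.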