feat: enable pytorch xpu support for non-attention models (#2561)
The XPU backend is available natively (without IPEX) in PyTorch starting
from version 2.4. This commit extends TGI to cover the case where the
user has XPU support through PyTorch 2.4 but does not have IPEX
installed. Models that don't require attention can now work; models that
do require attention need further work to provide an attention
implementation.
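
The detection pattern this change standardizes on, as a minimal standalone sketch:

import torch

# torch.xpu exists only in PyTorch builds >= 2.4 with XPU support,
# hence the hasattr() guard before calling is_available().
if hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
    print(f"native XPU devices: {torch.xpu.device_count()}")
else:
    device = torch.device("cpu")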

Tested with the following models:
* teknium/OpenHermes-2.5-Mistral-7B
* bigscience/bloom-560m
* google/gemma-7b
* google/flan-t5-xxl

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
dvrogozh authored Oct 14, 2024
1 parent 7a82ddc commit 58848cb
Showing 3 changed files with 35 additions and 21 deletions.
26 changes: 15 additions & 11 deletions server/text_generation_server/models/causal_lm.py
@@ -517,14 +517,13 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device(f"xpu:{rank}")
+            dtype = default_dtype if dtype is None else dtype
         elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = default_dtype if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                # Float16 doesn't exist on target.
-                dtype = torch.bfloat16 if dtype is None else dtype
+            device = torch.device("cpu")
+            # Float16 doesn't exist on target.
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
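
Consolidated from the hunk above, the post-change selection order is CUDA, then native XPU, then the IPEX CPU path, then plain CPU. A sketch (rank, default_dtype, and SYSTEM come from the surrounding module):

if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    dtype = default_dtype if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    # Native PyTorch XPU (>= 2.4) is now checked before the IPEX path,
    # so XPU works even when IPEX is not installed.
    device = torch.device(f"xpu:{rank}")
    dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
    device = torch.device("cpu")
    dtype = torch.bfloat16 if dtype is None else dtype  # float16 unsupported on this target
else:
    device = torch.device("cpu")
    dtype = torch.float32 if dtype is None else dtype
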
@@ -593,8 +592,14 @@ def fallback(
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
+        device_count = 0
         if torch.cuda.is_available():
             device = torch.device("cuda")
+            device_count = torch.cuda.device_count()
             dtype = torch.float16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device("xpu")
+            device_count = torch.xpu.device_count()
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
@@ -616,18 +621,17 @@ def fallback(
             torch_dtype=dtype,
             device_map=(
                 "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                if device_count > 1
                 else None
             ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if (
-            torch.cuda.is_available()
-            and torch.cuda.device_count() == 1
+            device_count == 1
             and quantize != "bitsandbytes"
         ):
-            model = model.cuda()
+            model = model.to(device)
 
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
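
The device_count refactor makes the multi-device/single-device split backend-agnostic. A self-contained sketch of the same placement logic, using one of the models tested above (assumes transformers and, for the multi-device path, accelerate are installed):

import torch
from transformers import AutoModelForCausalLM

device_count = 0
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
    device_count = torch.cuda.device_count()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
    device_count = torch.xpu.device_count()

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    # Shard across devices only when more than one accelerator is present.
    device_map="auto" if device_count > 1 else None,
)
if device_count == 1:
    # Single accelerator: move the whole model explicitly, as the diff does.
    model = model.to(device)
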
25 changes: 15 additions & 10 deletions server/text_generation_server/models/seq2seq_lm.py
@@ -558,14 +558,13 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device(f"xpu:{rank}")
+            dtype = default_dtype if dtype is None else dtype
         elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = default_dtype if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                # Float16 doesn't exist on target.
-                dtype = torch.bfloat16 if dtype is None else dtype
+            device = torch.device("cpu")
+            # Float16 doesn't exist on target.
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
@@ -630,8 +629,14 @@ def fallback(
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
+        device_count = 0
         if torch.cuda.is_available():
             device = torch.device("cuda")
+            device_count = torch.cuda.device_count()
             dtype = torch.float16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device("xpu")
+            device_count = torch.xpu.device_count()
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
@@ -646,14 +651,14 @@ def fallback(
             torch_dtype=dtype,
             device_map=(
                 "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                if device_count > 1
                 else None
             ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
-            model = model.cuda()
+        if device_count == 1:
+            model = model.to(device)
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
5 changes: 5 additions & 0 deletions server/text_generation_server/utils/import_utils.py
@@ -66,6 +66,11 @@ def noop(*args, **kwargs):
         empty_cache = noop
         synchronize = noop
         get_free_memory = get_cpu_free_memory
+elif hasattr(torch, "xpu") and torch.xpu.is_available():
+    SYSTEM = "xpu"
+    empty_cache = torch.xpu.empty_cache
+    synchronize = torch.xpu.synchronize
+    get_free_memory = get_xpu_free_memory
 else:
     SYSTEM = "cpu"
 
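
get_xpu_free_memory is referenced but not defined in this diff. A hypothetical sketch of what such a helper could look like, assuming it estimates a memory budget from the device's total memory and a caller-supplied fraction (the actual helper in TGI may differ):

import torch

def get_xpu_free_memory(device: torch.device, memory_fraction: float) -> int:
    # Hypothetical: torch.xpu in 2.4 exposes device properties but no
    # mem_get_info equivalent, so derive a budget from total memory.
    total_memory = torch.xpu.get_device_properties(device).total_memory
    return int(total_memory * memory_fraction)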
