Skip to content

Commit

Permalink
[backend] filter out non-arc devices
Browse files Browse the repository at this point in the history
  • Loading branch information
Nuullll committed Nov 8, 2024
1 parent fc53f1e commit 93602fe
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 57 deletions.
2 changes: 1 addition & 1 deletion WebUI/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "ai-playground",
"private": true,
"version": "1.22.0-beta",
"version": "1.22.1-beta",
"scripts": {
"dev": "cross-env VITE_PLATFORM_TITLE=\"for Local® Dev™ Mode\" vite",
"pack-python": "node build\\pack-python.js .\\package_res",
Expand Down
11 changes: 11 additions & 0 deletions service/device_detect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch
import intel_extension_for_pytorch as ipex

# filter out non-Arc devices
supported_ids = []
for i in range(torch.xpu.device_count()):
props = torch.xpu.get_device_properties(i)
if 'arc' in props.name.lower():
supported_ids.append(str(i))

print(','.join(supported_ids))
56 changes: 0 additions & 56 deletions service/env_setup.py

This file was deleted.

30 changes: 30 additions & 0 deletions service/web_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,35 @@
import sys

# Try to filter out unsupported devices
try:
# Create a subprocess to import IPEX and list devices
import subprocess
import os
script_dir = os.path.dirname(os.path.abspath(__file__))
device_detect_script = os.path.join(script_dir, "device_detect.py")

# Run the subprocess
result = subprocess.run(
[sys.executable, device_detect_script],
capture_output=True,
text=True,
)

# Check if the subprocess ran successfully
if result.returncode != 0:
raise Exception(f"Device detection failed: {result.stderr}")

# Get the supported device IDs
supported_ids = result.stdout.strip()
if not supported_ids:
raise Exception("No supported devices found")

# Set the environment variable to filter devices
os.environ["ONEAPI_DEVICE_SELECTOR"] = f"*:{supported_ids}"
print(f"Set ONEAPI_DEVICE_SELECTOR={os.environ['ONEAPI_DEVICE_SELECTOR']}")
except:
pass

# Credit to https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14186
# Related issues:
# + https://github.com/XPixelGroup/BasicSR/issues/649
Expand Down

0 comments on commit 93602fe

Please sign in to comment.