Skip to content

Commit

Permalink
jit misc
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jan 25, 2024
1 parent e0e078c commit 56692f7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
14 changes: 12 additions & 2 deletions python/mlc_chat/interface/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from mlc_chat.model import MODELS
from mlc_chat.support import logging
from mlc_chat.support.auto_device import device2str
from mlc_chat.support.constants import MLC_CACHE_DIR, MLC_JIT_POLICY, MLC_TEMP_DIR
from mlc_chat.support.constants import (
MLC_CACHE_DIR,
MLC_DSO_SUFFIX,
MLC_JIT_POLICY,
MLC_TEMP_DIR,
)
from mlc_chat.support.style import blue, bold

from .compiler_flags import ModelConfigOverride, OptimizationFlags
Expand All @@ -26,6 +31,11 @@

def jit(model_path: Path, chat_config: Dict[str, Any], device: Device) -> Path:
"""Just-in-time compile a MLC-Chat model."""
logger.info(
"%s = %s. Can be one of: ON, OFF, REDO, READONLY",
bold("MLC_JIT_POLICY"),
MLC_JIT_POLICY,
)
if MLC_JIT_POLICY == "OFF":
raise RuntimeError("JIT is disabled by MLC_JIT_POLICY=OFF")

Expand Down Expand Up @@ -64,7 +74,7 @@ def _get_model_config() -> Dict[str, Any]:

def _run_jit(opt: str, overrides: str, device: str, dst: str):
with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir:
dso_path = os.path.join(tmp_dir, "lib.so")
dso_path = os.path.join(tmp_dir, f"lib.{MLC_DSO_SUFFIX}")
cmd = [
sys.executable,
"-m",
Expand Down
11 changes: 11 additions & 0 deletions python/mlc_chat/support/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,21 @@ def _get_cache_dir() -> Path:
return result


def _get_dso_suffix() -> str:
if "MLC_DSO_SUFFIX" in os.environ:
return os.environ["MLC_DSO_SUFFIX"]
if sys.platform == "win32":
return ".dll"
if sys.platform == "darwin":
return ".dylib"
return ".so"


MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None)
MLC_MULTI_ARCH = os.environ.get("MLC_MULTI_ARCH", None)
MLC_CACHE_DIR: Path = _get_cache_dir()
MLC_JIT_POLICY = os.environ.get("MLC_JIT_POLICY", "ON")
MLC_DSO_SUFFIX = _get_dso_suffix()


_check()

0 comments on commit 56692f7

Please sign in to comment.