Skip to content

Commit

Permalink
Patch out the detection of MKL if requested
Browse files Browse the repository at this point in the history
  • Loading branch information
ashao committed Apr 2, 2024
1 parent 6f800b1 commit b3bc044
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 5 deletions.
17 changes: 17 additions & 0 deletions smartsim/_core/_cli/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def build_redis_ai(
torch_dir: t.Union[str, Path, None] = None,
libtf_dir: t.Union[str, Path, None] = None,
verbose: bool = False,
torch_with_mkl=True,
) -> None:
# make sure user isn't trying to do something silly on MacOS
if build_env.PLATFORM == "darwin" and device == Device.GPU:
Expand Down Expand Up @@ -186,6 +187,7 @@ def build_redis_ai(
build_tf=use_tf,
build_onnx=use_onnx,
verbose=verbose,
torch_with_mkl=torch_with_mkl,
)

if rai_builder.is_built:
Expand Down Expand Up @@ -414,6 +416,7 @@ def execute(
args.torch_dir,
args.libtensorflow_dir,
verbose=verbose,
torch_with_mkl=args.torch_with_mkl,
)
except (SetupError, BuildError) as e:
logger.error(str(e))
Expand Down Expand Up @@ -496,3 +499,17 @@ def configure_parser(parser: argparse.ArgumentParser) -> None:
default=False,
help="Build KeyDB instead of Redis",
)

parser.add_argument(
"--torch_with_mkl",
dest="torch_with_mkl",
action="store_true",
help="Build Torch with Intel MKL (if available)"
)
parser.add_argument(
"--no_torch_with_mkl",
dest="torch_with_mkl",
action="store_false",
help="Do not build Torch with Intel MKL"
)
parser.set_defaults(torch_with_mkl=True)
33 changes: 30 additions & 3 deletions smartsim/_core/_install/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import concurrent.futures
import enum
import fileinput
import itertools
import os
import platform
Expand All @@ -54,7 +55,7 @@

TRedisAIBackendStr = t.Literal["tensorflow", "torch", "onnxruntime", "tflite"]


_PathLike = t.TypeVar("_PathLike", Path, str, bytes)
_T = t.TypeVar("_T")
_U = t.TypeVar("_U")

Expand Down Expand Up @@ -410,6 +411,7 @@ def __init__(
build_onnx: bool = False,
jobs: int = 1,
verbose: bool = False,
torch_with_mkl: bool = True,
) -> None:
super().__init__(
build_env or {},
Expand All @@ -428,6 +430,9 @@ def __init__(
self.libtf_dir = libtf_dir
self.torch_dir = torch_dir

# extra configuration options
self.torch_with_mkl = torch_with_mkl

# Sanity checks
self._validate_platform()

Expand Down Expand Up @@ -517,8 +522,8 @@ def _get_deps_to_fetch_for(
# DLPack is always required
fetchable_deps: t.List[_RAIBuildDependency] = [_DLPackRepository("v0.5_RAI")]
if self.fetch_torch:
pt_dep = _choose_pt_variant(os_)
fetchable_deps.append(pt_dep(arch, device, "2.0.1"))
pt_dep = _choose_pt_variant(os_)(arch, device, "2.0.1", self.torch_with_mkl)
fetchable_deps.append(pt_dep)
if self.fetch_tf:
fetchable_deps.append(_TFArchive(os_, arch, device, "2.13.1"))
if self.fetch_onnx:
Expand Down Expand Up @@ -840,6 +845,7 @@ class _PTArchive(_WebZip, _RAIBuildDependency):
architecture: Architecture
device: Device
version: str
torch_with_mkl: bool

@staticmethod
def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
Expand All @@ -854,9 +860,19 @@ def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
def __rai_dependency_name__(self) -> str:
return f"libtorch@{self.url}"

@staticmethod
def _patch_out_mkl(libtorch_root: Path) -> None:
_modify_source_files(
libtorch_root / "share/cmake/Caffe2/public/mkl.cmake",
r"find_package\(MKL QUIET\)",
"# find_package(MKL QUIET)",
)

def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path:
self.extract(target)
target = Path(target) / "libtorch"
if not self.torch_with_mkl:
self._patch_out_mkl(target)
if not target.is_dir():
raise BuildError("Failed to place RAI dependency: `libtorch`")
return target
Expand Down Expand Up @@ -1051,3 +1067,14 @@ def config_git_command(plat: Platform, cmd: t.Sequence[str]) -> t.List[str]:
+ cmd[where:]
)
return cmd


def _modify_source_files(
files: t.Union[_PathLike, t.Iterable[_PathLike]], regex: str, replacement: str
) -> None:
compiled_regex = re.compile(regex)
patcher: t.Callable[[str], str] = lambda line: compiled_regex.sub(replacement, line)
with fileinput.input(files=files, encoding="utf-8", inplace=True) as f:
for line in f:
line = patcher(line)
print(line, end="")
36 changes: 34 additions & 2 deletions tests/install/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@

import functools
import pathlib
import platform
import threading
import textwrap
import time

import pytest
Expand Down Expand Up @@ -370,3 +369,36 @@ def test_valid_platforms():
)
def test_git_commands_are_configered_correctly_for_platforms(plat, cmd, expected_cmd):
assert build.config_git_command(plat, cmd) == expected_cmd


def test_modify_source_files(p_test_dir):
def make_text_blurb(food):
return textwrap.dedent(f"""\
My favorite food is {food}
{food} is an important part of a healthy breakfast
{food} {food} {food} {food}
This line should be unchanged!
--> {food} <--
""")

original_word = "SPAM"
mutated_word = "EGGS"

source_files = []
for i in range(3):
source_file = p_test_dir / f"test_{i}"
source_file.touch()
source_file.write_text(make_text_blurb(original_word))
source_files.append(source_file)
# Modify a single file
build._modify_source_files(source_files[0], original_word, mutated_word)
assert source_files[0].read_text() == make_text_blurb(mutated_word)
assert source_files[1].read_text() == make_text_blurb(original_word)
assert source_files[2].read_text() == make_text_blurb(original_word)

# Modify multiple files
build._modify_source_files(
(source_files[1], source_files[2]), original_word, mutated_word
)
assert source_files[1].read_text() == make_text_blurb(mutated_word)
assert source_files[2].read_text() == make_text_blurb(mutated_word)

0 comments on commit b3bc044

Please sign in to comment.