From e2d51d121216057268b483645ffdaa8a494b9b1f Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Tue, 2 Apr 2024 13:47:38 -0700 Subject: [PATCH 1/8] Patch out the detection of MKL if requested --- smartsim/_core/_cli/build.py | 10 +++++++++ smartsim/_core/_install/builder.py | 33 ++++++++++++++++++++++++--- tests/install/test_builder.py | 36 ++++++++++++++++++++++++++++-- 3 files changed, 74 insertions(+), 5 deletions(-) diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 08a1a6138..6423db8b5 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -139,6 +139,7 @@ def build_redis_ai( torch_dir: t.Union[str, Path, None] = None, libtf_dir: t.Union[str, Path, None] = None, verbose: bool = False, + torch_with_mkl=True, ) -> None: # make sure user isn't trying to do something silly on MacOS if build_env.PLATFORM == "darwin" and device == Device.GPU: @@ -186,6 +187,7 @@ def build_redis_ai( build_tf=use_tf, build_onnx=use_onnx, verbose=verbose, + torch_with_mkl=torch_with_mkl, ) if rai_builder.is_built: @@ -414,6 +416,7 @@ def execute( args.torch_dir, args.libtensorflow_dir, verbose=verbose, + torch_with_mkl=args.torch_with_mkl, ) except (SetupError, BuildError) as e: logger.error(str(e)) @@ -496,3 +499,10 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: default=False, help="Build KeyDB instead of Redis", ) + + parser.add_argument( + "--torch_with_mkl", + action="store_true", + default=True, + help="Build torch with Intel MKL support", + ) diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index 47f12d044..6a89dd1f1 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -28,6 +28,7 @@ import concurrent.futures import enum +import fileinput import itertools import os import platform @@ -54,7 +55,7 @@ TRedisAIBackendStr = t.Literal["tensorflow", "torch", "onnxruntime", "tflite"] - +_PathLike = t.TypeVar("_PathLike", Path, str, bytes) _T = t.TypeVar("_T") _U = t.TypeVar("_U") @@ -410,6 +411,7 @@ def __init__( build_onnx: bool = False, jobs: int = 1, verbose: bool = False, + torch_with_mkl: bool = True, ) -> None: super().__init__( build_env or {}, @@ -428,6 +430,9 @@ def __init__( self.libtf_dir = libtf_dir self.torch_dir = torch_dir + # extra configuration options + self.torch_with_mkl = torch_with_mkl + # Sanity checks self._validate_platform() @@ -517,8 +522,8 @@ def _get_deps_to_fetch_for( # DLPack is always required fetchable_deps: t.List[_RAIBuildDependency] = [_DLPackRepository("v0.5_RAI")] if self.fetch_torch: - pt_dep = _choose_pt_variant(os_) - fetchable_deps.append(pt_dep(arch, device, "2.0.1")) + pt_dep = _choose_pt_variant(os_)(arch, device, "2.0.1", self.torch_with_mkl) + fetchable_deps.append(pt_dep) if self.fetch_tf: fetchable_deps.append(_TFArchive(os_, arch, device, "2.13.1")) if self.fetch_onnx: @@ -840,6 +845,7 @@ class _PTArchive(_WebZip, _RAIBuildDependency): architecture: Architecture device: Device version: str + torch_with_mkl: bool @staticmethod def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]: @@ -854,9 +860,19 @@ def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]: def __rai_dependency_name__(self) -> str: return f"libtorch@{self.url}" + @staticmethod + def _patch_out_mkl(libtorch_root: Path) -> None: + _modify_source_files( + libtorch_root / "share/cmake/Caffe2/public/mkl.cmake", + "find_package(MKL QUIET)", + "# find_package(MKL QUIET)", + ) + def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: self.extract(target) target = Path(target) / "libtorch" + if not self.torch_with_mkl: + self._patch_out_mkl(target) if not target.is_dir(): raise BuildError("Failed to place RAI dependency: `libtorch`") return target @@ -1051,3 +1067,14 @@ def config_git_command(plat: Platform, cmd: t.Sequence[str]) -> t.List[str]: + cmd[where:] ) return cmd + + +def _modify_source_files( + files: t.Union[_PathLike, t.Iterable[_PathLike]], regex: str, replacement: str +) -> None: + compiled_regex = re.compile(regex) + patcher: t.Callable[[str], str] = lambda line: compiled_regex.sub(replacement, line) + with fileinput.input(files=files, encoding="utf-8", inplace=True) as f: + for line in f: + line = patcher(line) + print(line, end="") diff --git a/tests/install/test_builder.py b/tests/install/test_builder.py index c69a083d1..ce0439078 100644 --- a/tests/install/test_builder.py +++ b/tests/install/test_builder.py @@ -27,8 +27,7 @@ import functools import pathlib -import platform -import threading +import textwrap import time import pytest @@ -370,3 +369,36 @@ def test_valid_platforms(): ) def test_git_commands_are_configered_correctly_for_platforms(plat, cmd, expected_cmd): assert build.config_git_command(plat, cmd) == expected_cmd + + +def test_modify_source_files(p_test_dir): + def make_text_blurb(food): + return textwrap.dedent(f"""\ + My favorite food is {food} + {food} is an important part of a healthy breakfast + {food} {food} {food} {food} + This line should be unchanged! + --> {food} <-- + """) + + original_word = "SPAM" + mutated_word = "EGGS" + + source_files = [] + for i in range(3): + source_file = p_test_dir / f"test_{i}" + source_file.touch() + source_file.write_text(make_text_blurb(original_word)) + source_files.append(source_file) + # Modify a single file + build._modify_source_files(source_files[0], original_word, mutated_word) + assert source_files[0].read_text() == make_text_blurb(mutated_word) + assert source_files[1].read_text() == make_text_blurb(original_word) + assert source_files[2].read_text() == make_text_blurb(original_word) + + # Modify multiple files + build._modify_source_files( + (source_files[1], source_files[2]), original_word, mutated_word + ) + assert source_files[1].read_text() == make_text_blurb(mutated_word) + assert source_files[2].read_text() == make_text_blurb(mutated_word) From a376eac2c729f5e791a10d1241f822c0a3f2f5a4 Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Tue, 2 Apr 2024 13:47:38 -0700 Subject: [PATCH 2/8] Patch out the detection of MKL if requested --- doc/changelog.rst | 7 ++++++ smartsim/_core/_cli/build.py | 17 ++++++++++++++ smartsim/_core/_install/builder.py | 33 ++++++++++++++++++++++++--- tests/install/test_builder.py | 36 ++++++++++++++++++++++++++++-- 4 files changed, 88 insertions(+), 5 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index c72513d04..cf102a9b7 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -18,6 +18,7 @@ To be released at some future point in time Description +- Add option to build Torch backend without the Intel Math Kernel Library - Promote device options to an Enum - Update telemetry monitor, add telemetry collectors - Add method to specify node features for a Slurm job @@ -34,6 +35,11 @@ Description Detailed Notes +- Add an option to smart build "--torch_with_mkl"/"--no_torch_with_mkl" to + prevent Torch from trying to link in the Intel Math Kernel Library. This + is needed because on machines that have the Intel compilers installed, the + Torch will unconditionally try to link in this library, however fails + because the linking flags are incorrect. (SmartSim-PR538_) - Promote devices to a dedicated Enum type throughout the SmartSim code base. - Update the telemetry monitor to enable retrieval of metrics on a scheduled interval. Switch basic experiment tracking telemetry to default to on. Add @@ -74,6 +80,7 @@ Detailed Notes - Remove previously deprecated behavior present in test suite on machines with Slurm and Open MPI. (SmartSim-PR520_) +.. _SmartSim-PR538: https://github.com/CrayLabs/SmartSim/pull/538 .. _SmartSim-PR498: https://github.com/CrayLabs/SmartSim/pull/498 .. _SmartSim-PR460: https://github.com/CrayLabs/SmartSim/pull/460 .. _SmartSim-PR512: https://github.com/CrayLabs/SmartSim/pull/512 diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 08a1a6138..f8c30a0cd 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -139,6 +139,7 @@ def build_redis_ai( torch_dir: t.Union[str, Path, None] = None, libtf_dir: t.Union[str, Path, None] = None, verbose: bool = False, + torch_with_mkl=True, ) -> None: # make sure user isn't trying to do something silly on MacOS if build_env.PLATFORM == "darwin" and device == Device.GPU: @@ -186,6 +187,7 @@ def build_redis_ai( build_tf=use_tf, build_onnx=use_onnx, verbose=verbose, + torch_with_mkl=torch_with_mkl, ) if rai_builder.is_built: @@ -414,6 +416,7 @@ def execute( args.torch_dir, args.libtensorflow_dir, verbose=verbose, + torch_with_mkl=args.torch_with_mkl, ) except (SetupError, BuildError) as e: logger.error(str(e)) @@ -496,3 +499,17 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: default=False, help="Build KeyDB instead of Redis", ) + + parser.add_argument( + "--torch_with_mkl", + dest="torch_with_mkl", + action="store_true", + help="Build Torch with Intel MKL (if available)" + ) + parser.add_argument( + "--no_torch_with_mkl", + dest="torch_with_mkl", + action="store_false", + help="Do not build Torch with Intel MKL" + ) + parser.set_defaults(torch_with_mkl=True) diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index 47f12d044..922b68d8c 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -28,6 +28,7 @@ import concurrent.futures import enum +import fileinput import itertools import os import platform @@ -54,7 +55,7 @@ TRedisAIBackendStr = t.Literal["tensorflow", "torch", "onnxruntime", "tflite"] - +_PathLike = t.TypeVar("_PathLike", Path, str, bytes) _T = t.TypeVar("_T") _U = t.TypeVar("_U") @@ -410,6 +411,7 @@ def __init__( build_onnx: bool = False, jobs: int = 1, verbose: bool = False, + torch_with_mkl: bool = True, ) -> None: super().__init__( build_env or {}, @@ -428,6 +430,9 @@ def __init__( self.libtf_dir = libtf_dir self.torch_dir = torch_dir + # extra configuration options + self.torch_with_mkl = torch_with_mkl + # Sanity checks self._validate_platform() @@ -517,8 +522,8 @@ def _get_deps_to_fetch_for( # DLPack is always required fetchable_deps: t.List[_RAIBuildDependency] = [_DLPackRepository("v0.5_RAI")] if self.fetch_torch: - pt_dep = _choose_pt_variant(os_) - fetchable_deps.append(pt_dep(arch, device, "2.0.1")) + pt_dep = _choose_pt_variant(os_)(arch, device, "2.0.1", self.torch_with_mkl) + fetchable_deps.append(pt_dep) if self.fetch_tf: fetchable_deps.append(_TFArchive(os_, arch, device, "2.13.1")) if self.fetch_onnx: @@ -840,6 +845,7 @@ class _PTArchive(_WebZip, _RAIBuildDependency): architecture: Architecture device: Device version: str + torch_with_mkl: bool @staticmethod def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]: @@ -854,9 +860,19 @@ def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]: def __rai_dependency_name__(self) -> str: return f"libtorch@{self.url}" + @staticmethod + def _patch_out_mkl(libtorch_root: Path) -> None: + _modify_source_files( + libtorch_root / "share/cmake/Caffe2/public/mkl.cmake", + r"find_package\(MKL QUIET\)", + "# find_package(MKL QUIET)", + ) + def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: self.extract(target) target = Path(target) / "libtorch" + if not self.torch_with_mkl: + self._patch_out_mkl(target) if not target.is_dir(): raise BuildError("Failed to place RAI dependency: `libtorch`") return target @@ -1051,3 +1067,14 @@ def config_git_command(plat: Platform, cmd: t.Sequence[str]) -> t.List[str]: + cmd[where:] ) return cmd + + +def _modify_source_files( + files: t.Union[_PathLike, t.Iterable[_PathLike]], regex: str, replacement: str +) -> None: + compiled_regex = re.compile(regex) + patcher: t.Callable[[str], str] = lambda line: compiled_regex.sub(replacement, line) + with fileinput.input(files=files, encoding="utf-8", inplace=True) as f: + for line in f: + line = patcher(line) + print(line, end="") diff --git a/tests/install/test_builder.py b/tests/install/test_builder.py index c69a083d1..ce0439078 100644 --- a/tests/install/test_builder.py +++ b/tests/install/test_builder.py @@ -27,8 +27,7 @@ import functools import pathlib -import platform -import threading +import textwrap import time import pytest @@ -370,3 +369,36 @@ def test_valid_platforms(): ) def test_git_commands_are_configered_correctly_for_platforms(plat, cmd, expected_cmd): assert build.config_git_command(plat, cmd) == expected_cmd + + +def test_modify_source_files(p_test_dir): + def make_text_blurb(food): + return textwrap.dedent(f"""\ + My favorite food is {food} + {food} is an important part of a healthy breakfast + {food} {food} {food} {food} + This line should be unchanged! + --> {food} <-- + """) + + original_word = "SPAM" + mutated_word = "EGGS" + + source_files = [] + for i in range(3): + source_file = p_test_dir / f"test_{i}" + source_file.touch() + source_file.write_text(make_text_blurb(original_word)) + source_files.append(source_file) + # Modify a single file + build._modify_source_files(source_files[0], original_word, mutated_word) + assert source_files[0].read_text() == make_text_blurb(mutated_word) + assert source_files[1].read_text() == make_text_blurb(original_word) + assert source_files[2].read_text() == make_text_blurb(original_word) + + # Modify multiple files + build._modify_source_files( + (source_files[1], source_files[2]), original_word, mutated_word + ) + assert source_files[1].read_text() == make_text_blurb(mutated_word) + assert source_files[2].read_text() == make_text_blurb(mutated_word) From 95d1886bad1ef0ae3bfe13a8e66cc508bbe72ea7 Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Tue, 2 Apr 2024 16:53:46 -0700 Subject: [PATCH 3/8] Add typehint --- smartsim/_core/_cli/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 6423db8b5..77974ae08 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -139,7 +139,7 @@ def build_redis_ai( torch_dir: t.Union[str, Path, None] = None, libtf_dir: t.Union[str, Path, None] = None, verbose: bool = False, - torch_with_mkl=True, + torch_with_mkl: bool = True, ) -> None: # make sure user isn't trying to do something silly on MacOS if build_env.PLATFORM == "darwin" and device == Device.GPU: From daa975dfdb09d025b10209f2fedd8c873e65627b Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Wed, 3 Apr 2024 15:22:06 -0700 Subject: [PATCH 4/8] Response to reviewer feedback --- smartsim/_core/_cli/build.py | 16 +--------------- smartsim/_core/_install/builder.py | 9 ++++----- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 6402d5ca7..70b6a5479 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -500,24 +500,10 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: help="Build KeyDB instead of Redis", ) - parser.add_argument( -<<<<<<< HEAD - "--torch_with_mkl", - action="store_true", - default=True, - help="Build torch with Intel MKL support", - ) -======= - "--torch_with_mkl", - dest="torch_with_mkl", - action="store_true", - help="Build Torch with Intel MKL (if available)" - ) parser.add_argument( "--no_torch_with_mkl", dest="torch_with_mkl", action="store_false", - help="Do not build Torch with Intel MKL" + help="Do not build Torch with Intel MKL", ) parser.set_defaults(torch_with_mkl=True) ->>>>>>> a376eac2c729f5e791a10d1241f822c0a3f2f5a4 diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index 922b68d8c..cfe01af95 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -55,7 +55,7 @@ TRedisAIBackendStr = t.Literal["tensorflow", "torch", "onnxruntime", "tflite"] -_PathLike = t.TypeVar("_PathLike", Path, str, bytes) +_PathLike = t.Union[Path, str, bytes] _T = t.TypeVar("_T") _U = t.TypeVar("_U") @@ -845,7 +845,7 @@ class _PTArchive(_WebZip, _RAIBuildDependency): architecture: Architecture device: Device version: str - torch_with_mkl: bool + with_mkl: bool @staticmethod def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]: @@ -871,7 +871,7 @@ def _patch_out_mkl(libtorch_root: Path) -> None: def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: self.extract(target) target = Path(target) / "libtorch" - if not self.torch_with_mkl: + if not self.with_mkl: self._patch_out_mkl(target) if not target.is_dir(): raise BuildError("Failed to place RAI dependency: `libtorch`") @@ -1073,8 +1073,7 @@ def _modify_source_files( files: t.Union[_PathLike, t.Iterable[_PathLike]], regex: str, replacement: str ) -> None: compiled_regex = re.compile(regex) - patcher: t.Callable[[str], str] = lambda line: compiled_regex.sub(replacement, line) with fileinput.input(files=files, encoding="utf-8", inplace=True) as f: for line in f: - line = patcher(line) + line = compiled_regex.sub(replacement, line) print(line, end="") From 9ce12ffa16e0afed43260c7a6f07987ab5e642f7 Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Wed, 3 Apr 2024 19:20:28 -0700 Subject: [PATCH 5/8] dust lint --- smartsim/_core/_install/builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index cfe01af95..a2167d215 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -1073,7 +1073,7 @@ def _modify_source_files( files: t.Union[_PathLike, t.Iterable[_PathLike]], regex: str, replacement: str ) -> None: compiled_regex = re.compile(regex) - with fileinput.input(files=files, encoding="utf-8", inplace=True) as f: - for line in f: + with fileinput.input(files=files, encoding="utf-8", inplace=True) as handles: + for line in handles: line = compiled_regex.sub(replacement, line) print(line, end="") From a27344be52ab89eeb05d44c760b842ad20f93dd0 Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Thu, 4 Apr 2024 09:52:36 -0700 Subject: [PATCH 6/8] Remove encoding spec for older fileinput --- smartsim/_core/_install/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index a2167d215..3818faf8b 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -1073,7 +1073,7 @@ def _modify_source_files( files: t.Union[_PathLike, t.Iterable[_PathLike]], regex: str, replacement: str ) -> None: compiled_regex = re.compile(regex) - with fileinput.input(files=files, encoding="utf-8", inplace=True) as handles: + with fileinput.input(files=files, inplace=True) as handles: for line in handles: line = compiled_regex.sub(replacement, line) print(line, end="") From cab008b67cc9bf7c7fbff3705e48bdbfebe318bb Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Thu, 4 Apr 2024 14:02:33 -0700 Subject: [PATCH 7/8] Refactor type hints --- smartsim/_core/_cli/build.py | 1 - smartsim/_core/_install/builder.py | 40 ++++++++++++++---------------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 70b6a5479..ab982ac1b 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -506,4 +506,3 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: action="store_false", help="Do not build Torch with Intel MKL", ) - parser.set_defaults(torch_with_mkl=True) diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index 3818faf8b..d0dbc5a6a 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -54,8 +54,7 @@ # TODO: check cmake version and use system if possible to avoid conflicts TRedisAIBackendStr = t.Literal["tensorflow", "torch", "onnxruntime", "tflite"] - -_PathLike = t.Union[Path, str, bytes] +_PathLike = t.Union[str, "os.PathLike[str]"] _T = t.TypeVar("_T") _U = t.TypeVar("_U") @@ -370,7 +369,7 @@ class _RAIBuildDependency(ABC): def __rai_dependency_name__(self) -> str: ... @abstractmethod - def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: ... + def __place_for_rai__(self, target: _PathLike) -> Path: ... @staticmethod @abstractmethod @@ -378,7 +377,7 @@ def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]: def _place_rai_dep_at( - target: t.Union[str, "os.PathLike[str]"], verbose: bool + target: _PathLike, verbose: bool ) -> t.Callable[[_RAIBuildDependency], Path]: def _place(dep: _RAIBuildDependency) -> Path: if verbose: @@ -760,7 +759,7 @@ def url(self) -> str: ... class _WebGitRepository(_WebLocation): def clone( self, - target: t.Union[str, "os.PathLike[str]"], + target: _PathLike, depth: t.Optional[int] = None, branch: t.Optional[str] = None, ) -> None: @@ -790,7 +789,7 @@ def url(self) -> str: def __rai_dependency_name__(self) -> str: return f"dlpack@{self.url}" - def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: + def __place_for_rai__(self, target: _PathLike) -> Path: target = Path(target) / "dlpack" self.clone(target, branch=self.version, depth=1) if not target.is_dir(): @@ -804,7 +803,7 @@ def name(self) -> str: _, name = self.url.rsplit("/", 1) return name - def download(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: + def download(self, target: _PathLike) -> Path: target = Path(target) if target.is_dir(): target = target / self.name @@ -814,28 +813,22 @@ def download(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: class _ExtractableWebArchive(_WebArchive, ABC): @abstractmethod - def _extract_download( - self, download_path: Path, target: t.Union[str, "os.PathLike[str]"] - ) -> None: ... + def _extract_download(self, download_path: Path, target: _PathLike) -> None: ... - def extract(self, target: t.Union[str, "os.PathLike[str]"]) -> None: + def extract(self, target: _PathLike) -> None: with tempfile.TemporaryDirectory() as tmp_dir: arch_path = self.download(tmp_dir) self._extract_download(arch_path, target) class _WebTGZ(_ExtractableWebArchive): - def _extract_download( - self, download_path: Path, target: t.Union[str, "os.PathLike[str]"] - ) -> None: + def _extract_download(self, download_path: Path, target: _PathLike) -> None: with tarfile.open(download_path, "r") as tgz_file: tgz_file.extractall(target) class _WebZip(_ExtractableWebArchive): - def _extract_download( - self, download_path: Path, target: t.Union[str, "os.PathLike[str]"] - ) -> None: + def _extract_download(self, download_path: Path, target: _PathLike) -> None: with zipfile.ZipFile(download_path, "r") as zip_file: zip_file.extractall(target) @@ -868,11 +861,14 @@ def _patch_out_mkl(libtorch_root: Path) -> None: "# find_package(MKL QUIET)", ) - def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: + def extract(self, target: _PathLike) -> None: + super().extract(target) + if not self.with_mkl: + self._patch_out_mkl(Path(target)) + + def __place_for_rai__(self, target: _PathLike) -> Path: self.extract(target) target = Path(target) / "libtorch" - if not self.with_mkl: - self._patch_out_mkl(target) if not target.is_dir(): raise BuildError("Failed to place RAI dependency: `libtorch`") return target @@ -980,7 +976,7 @@ def url(self) -> str: def __rai_dependency_name__(self) -> str: return f"libtensorflow@{self.url}" - def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: + def __place_for_rai__(self, target: _PathLike) -> Path: target = Path(target) / "libtensorflow" target.mkdir() self.extract(target) @@ -1026,7 +1022,7 @@ def url(self) -> str: def __rai_dependency_name__(self) -> str: return f"onnxruntime@{self.url}" - def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: + def __place_for_rai__(self, target: _PathLike) -> Path: target = Path(target).resolve() / "onnxruntime" self.extract(target) try: From f5a298c0dfd02f298a0dba4c3baf866c34f94a69 Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Thu, 4 Apr 2024 15:35:58 -0700 Subject: [PATCH 8/8] Fix cranky tests --- tests/install/test_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/install/test_builder.py b/tests/install/test_builder.py index ce0439078..feaf7e54f 100644 --- a/tests/install/test_builder.py +++ b/tests/install/test_builder.py @@ -253,13 +253,13 @@ def test_PTArchiveMacOSX_url(): pt_version = RAI_VERSIONS.torch pt_linux_cpu = build._PTArchiveLinux( - build.Architecture.X64, build.Device.CPU, pt_version + build.Architecture.X64, build.Device.CPU, pt_version, False ) x64_prefix = "https://download.pytorch.org/libtorch/" assert x64_prefix in pt_linux_cpu.url pt_macosx_cpu = build._PTArchiveMacOSX( - build.Architecture.ARM64, build.Device.CPU, pt_version + build.Architecture.ARM64, build.Device.CPU, pt_version, False ) arm64_prefix = "https://github.com/CrayLabs/ml_lib_builder/releases/download/" assert arm64_prefix in pt_macosx_cpu.url @@ -268,7 +268,7 @@ def test_PTArchiveMacOSX_url(): def test_PTArchiveMacOSX_gpu_error(): with pytest.raises(build.BuildError, match="support GPU on Mac OSX"): build._PTArchiveMacOSX( - build.Architecture.ARM64, build.Device.GPU, RAI_VERSIONS.torch + build.Architecture.ARM64, build.Device.GPU, RAI_VERSIONS.torch, False ).url