Skip to content

Commit

Permalink
Pass current machine's architecture to rapids-dependency-file-generator.
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice committed Oct 9, 2024
1 parent 340e9a5 commit 2bf78f1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
14 changes: 14 additions & 0 deletions rapids_build_backend/impls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import os
import platform
import re
import shutil
import subprocess
Expand Down Expand Up @@ -51,6 +52,18 @@ def _get_backend(build_backend):
)


@lru_cache
def _get_arch():
"""Get the arch of the current machine.
Returns
-------
str
The arch (e.g. "x86_64" or "aarch64")
"""
return platform.machine()


@lru_cache
def _get_cuda_version():
"""Get the CUDA suffix based on nvcc.
Expand Down Expand Up @@ -190,6 +203,7 @@ def _edit_pyproject(config):
matrix = _parse_matrix(config.matrix_entry) or dict(file_config.matrix)
if not config.disable_cuda:
matrix["cuda"] = [f"{cuda_version_major}.{cuda_version_minor}"]
matrix["arch"] = [_get_arch()]
rapids_dependency_file_generator.make_dependency_files(
parsed_config=parsed_config,
file_keys=[file_key],
Expand Down
30 changes: 23 additions & 7 deletions tests/test_impls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import os.path
import platform
from contextlib import contextmanager
from textwrap import dedent
from unittest.mock import Mock, patch
Expand All @@ -10,6 +11,7 @@
from rapids_build_backend.impls import (
_check_setup_py,
_edit_pyproject,
_get_arch,
_get_cuda_suffix,
_remove_rapidsai_from_config,
_write_git_commits,
Expand Down Expand Up @@ -88,8 +90,9 @@ def test_write_git_commits(
"cuda_version",
"cuda_suffix",
"cuda_python_requirement",
"matrix",
"arch",
"arch_requirement",
"matrix",
],
[
(
Expand All @@ -100,8 +103,9 @@ def test_write_git_commits(
("11", "5"),
"-cu11",
"cuda-python>=11.5,<11.6.dev0",
"",
"x86_64",
"some-x86-package",
"",
),
(
".",
Expand All @@ -111,8 +115,9 @@ def test_write_git_commits(
("11", "5"),
"-cu11",
"cuda-python>=11.5,<11.6.dev0",
"arch=aarch64",
"aarch64",
"some-arm-package",
"",
),
(
"python",
Expand All @@ -122,8 +127,9 @@ def test_write_git_commits(
("12", "1"),
"-cu12",
"cuda-python>=12.1,<12.2.dev0",
"",
"x86_64",
"some-x86-package",
"",
),
(
".",
Expand All @@ -133,8 +139,11 @@ def test_write_git_commits(
("11", "5"),
"-cu11",
None,
None, # Test the arch detection logic
"some-x86-package"
if platform.machine() == "x86_64"
else "some-arm-package",
"",
None,
),
(
".",
Expand All @@ -144,8 +153,11 @@ def test_write_git_commits(
None, # Ensure _get_cuda_version() isn't called and unpacked
"",
"cuda-python",
None, # Test the arch detection logic
"some-x86-package"
if platform.machine() == "x86_64"
else "some-arm-package",
"",
"some-x86-package",
),
],
)
Expand All @@ -158,8 +170,9 @@ def test_edit_pyproject(
cuda_version,
cuda_suffix,
cuda_python_requirement,
matrix,
arch,
arch_requirement,
matrix,
):
original_contents = dedent(
"""\
Expand Down Expand Up @@ -265,6 +278,9 @@ def test_edit_pyproject(
)

with patch(
"rapids_build_backend.impls._get_arch",
Mock(return_value=arch) if arch is not None else _get_arch,
), patch(
"rapids_build_backend.impls._get_cuda_version",
Mock(return_value=cuda_version),
), patch(
Expand Down

0 comments on commit 2bf78f1

Please sign in to comment.