Skip to content

Commit

Permalink
Optionally patch linker for LTO
Browse files Browse the repository at this point in the history
Add a test that succeeds linking LTO LinkableCode with Numba using this
new feature.
  • Loading branch information
gmarkall committed Apr 19, 2024
1 parent e5dafbd commit 019a002
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
6 changes: 4 additions & 2 deletions pynvjitlink/patch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
from functools import partial
from pynvjitlink.api import NvJitLinker, NvJitLinkError

import os
Expand Down Expand Up @@ -132,6 +133,7 @@ def __init__(
options.extend(additional_flags)

self._linker = NvJitLinker(*options)
self.lto = lto
self.options = options

@property
Expand Down Expand Up @@ -250,14 +252,14 @@ def new_patched_linker(
)


def patch_numba_linker():
def patch_numba_linker(*, lto=False):
if not _numba_version_ok:
msg = f"Cannot patch Numba: {_numba_error}"
raise RuntimeError(msg)

# Replace the built-in linker that uses the Driver API with our linker that
# uses nvJitLink
Linker.new = new_patched_linker
Linker.new = partial(new_patched_linker, lto=lto)

# Add linkable code objects to Numba's top-level API
cuda.Archive = Archive
Expand Down
35 changes: 28 additions & 7 deletions pynvjitlink/tests/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_numba_patching():
from numba.cuda.cudadrv.driver import Linker

patch_numba_linker()
assert Linker.new is new_patched_linker
assert Linker.new.func is new_patched_linker


def test_create():
Expand Down Expand Up @@ -133,12 +133,6 @@ def test_add_file_guess_ext_invalid_input(
"linkable_code_fatbin",
"linkable_code_object",
"linkable_code_ptx",
pytest.param(
"linkable_code_ltoir",
marks=pytest.mark.xfail(
reason=".ltoir file is actually an object and lto=True missing"
),
),
),
)
def test_jit_with_linkable_code(file, request):
Expand All @@ -157,6 +151,33 @@ def kernel(result):
assert result[0] == 3


@pytest.fixture
def numba_linking_with_lto():
"""
Patch the linker for LTO for the duration of the test.
Afterwards, restore the linker to whatever it was before.
"""
from numba.cuda.cudadrv.driver import Linker

old_new = Linker.new
patch_numba_linker(lto=True)
yield
Linker.new = old_new


def test_jit_with_linkable_code_lto(linkable_code_ltoir, numba_linking_with_lto):
sig = "uint32(uint32, uint32)"
add_from_numba = cuda.declare_device("add_from_numba", sig)

@cuda.jit(link=[linkable_code_ltoir])
def kernel(result):
result[0] = add_from_numba(1, 2)

result = cuda.device_array(1)
kernel[1, 1](result)
assert result[0] == 3


@pytest.mark.skipif(
not _numba_version_ok,
reason=f"Requires Numba == {required_numba_ver[0]}.{required_numba_ver[1]}",
Expand Down

0 comments on commit 019a002

Please sign in to comment.