Skip to content

Commit

Permalink
made the numba type for the ProgressBar class accessible for use in t…
Browse files Browse the repository at this point in the history
…yped signatures
  • Loading branch information
Felix Igelbrink committed May 22, 2023
1 parent cb70367 commit bd00fc4
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 15 deletions.
15 changes: 15 additions & 0 deletions examples/clock_print.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sleep import clock, usleep

import numba as nb

@nb.njit(nogil=True)
def numba_clock():
c1 = clock()
print("c1", c1)
usleep(1000000)
c2 = clock()
print("c2", c2)
print("time", c2-c1)

if __name__ == "__main__":
numba_clock()
19 changes: 19 additions & 0 deletions examples/example_signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# example code to use the progressbar with an explicit signature

from sleep import usleep
import numba as nb
from numba_progress import ProgressBar, ProgressBarType


@nb.njit(nb.void(nb.uint64, nb.uint64, ProgressBarType), nogil=True)
def numba_sleeper(num_iterations, sleep_us, progress_hook):
for i in range(num_iterations):
usleep(sleep_us)
progress_hook.update(1)


if __name__ == "__main__":
num_iterations = 30
sleep_time_us = 250_000
with ProgressBar(total=num_iterations, ncols=80) as numba_progress:
numba_sleeper(num_iterations, sleep_time_us, numba_progress)
16 changes: 15 additions & 1 deletion examples/sleep.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,18 @@ def usleep(usec):
def usleep_impl(usec):
func_usleep(usec)

return usleep_impl
return usleep_impl

@nb.generated_jit(nogil=True, nopython=True)
def clock():
import ctypes
libc = ctypes.CDLL('libc.so.6')
libc.clock.argtypes = ()
libc.clock.restype = ctypes.c_ulong
func_clock = libc.clock
CLOCKS_PER_SEC = 1_000_000

def clock_impl():
return func_clock() / CLOCKS_PER_SEC

return clock_impl
2 changes: 1 addition & 1 deletion numba_progress/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .progress import ProgressBar
from .progress import *
from ._version import __version__
2 changes: 1 addition & 1 deletion numba_progress/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.2"
__version__ = "1.0.0"
25 changes: 13 additions & 12 deletions numba_progress/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numba.core import cgutils
from numba.core.boxing import unbox_array

__all__ = ['ProgressBar']
__all__ = ['ProgressBar', 'ProgressBarType']

def is_notebook():
"""Determine if we're running within an IPython kernel
Expand Down Expand Up @@ -119,23 +119,24 @@ def __exit__(self, exc_type, exc_val, exc_tb):

# Numba Native Implementation for the ProgressBar Class

class ProgressBarType(types.Type):
class ProgressBarTypeImpl(types.Type):
def __init__(self):
super().__init__(name='ProgressBar')


progressbar_type = ProgressBarType()
# This is the numba type representation of the ProgressBar class to be used in signatures
ProgressBarType = ProgressBarTypeImpl()


@typeof_impl.register(ProgressBar)
def typeof_index(val, c):
return progressbar_type
return ProgressBarType


as_numba_type.register(ProgressBar, progressbar_type)
as_numba_type.register(ProgressBar, ProgressBarType)


@register_model(ProgressBarType)
@register_model(ProgressBarTypeImpl)
class ProgressBarModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [
Expand All @@ -145,17 +146,17 @@ def __init__(self, dmm, fe_type):


# make the hook attribute accessible
make_attribute_wrapper(ProgressBarType, 'hook', 'hook')
make_attribute_wrapper(ProgressBarTypeImpl, 'hook', 'hook')


@overload_attribute(ProgressBarType, 'value')
@overload_attribute(ProgressBarTypeImpl, 'value')
def get_value(progress_bar):
def getter(progress_bar):
return progress_bar.hook[0]
return getter


@unbox(ProgressBarType)
@unbox(ProgressBarTypeImpl)
def unbox_progressbar(typ, obj, c):
"""
Convert a ProgressBar to it's native representation (proxy object)
Expand All @@ -168,18 +169,18 @@ def unbox_progressbar(typ, obj, c):
return NativeValue(progress_bar._getvalue(), is_error=is_error)


@box(ProgressBarType)
@box(ProgressBarTypeImpl)
def box_progressbar(typ, val, c):
raise TypeError("Native representation of ProgressBar cannot be converted back to a python object "
"as it contains internal python state.")


@overload_method(ProgressBarType, "update", jit_options={"nogil": True})
@overload_method(ProgressBarTypeImpl, "update", jit_options={"nogil": True})
def _ol_update(self, n=1):
"""
Numpy implementation of the update method.
"""
if isinstance(self, ProgressBarType):
if isinstance(self, ProgressBarTypeImpl):
def _update_impl(self, n=1):
atomic_add(self.hook, 0, n)
return _update_impl
Expand Down

0 comments on commit bd00fc4

Please sign in to comment.