Skip to content

Commit

Permalink
Adds state to the tracker and raises exceptions if starting / stoppin…
Browse files Browse the repository at this point in the history
…g in the wrong state
  • Loading branch information
erikhuck committed Apr 15, 2024
1 parent ea63727 commit 0926894
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/gpu_tracker/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@
import psutil
import subprocess as subp
import logging as log
import enum
import sys


class Tracker:
class State(enum.Enum):
NEW = 0
STARTED = 1
STOPPED = 2
"""
Runs a thread in the background that tracks the compute time, maximum RAM, and maximum GPU RAM usage within a context manager or explicit ``start()`` and ``stop()`` methods.
Calculated quantities are scaled depending on the units chosen for them (e.g. megabytes vs. gigabytes, hours vs. days, etc.).
Expand Down Expand Up @@ -78,6 +83,10 @@ def __init__(
self.max_gpu_ram = MaxGPURAM(unit=gpu_ram_unit, system_capacity=self._system_gpu_ram(measurement='total'))
self.cpu_utilization = CPUUtilization(system_core_count=psutil.cpu_count())
self.compute_time = ComputeTime(unit=time_unit)
self.state = Tracker.State.NEW

def __repr__(self):
return (f'State: {self.state.name}')

def _log_warning(self, warning: str):
if not self.disable_logs:
Expand Down Expand Up @@ -228,10 +237,19 @@ def _profile(self):
print(error)

def __enter__(self) -> Tracker:
if self.state == Tracker.State.STARTED:
raise RuntimeError('Cannot start tracking when tracking has already started.')
elif self.state == Tracker.State.STOPPED:
raise RuntimeError('Cannot start tracking when tracking has already stopped.')
self.state = Tracker.State.STARTED
self._thread.start()
return self

def __exit__(self, *_):
if self.state == Tracker.State.NEW:
raise RuntimeError('Cannot stop tracking when tracking has not started yet.')
if self.state == Tracker.State.STOPPED:
raise RuntimeError('Cannot stop tracking when tracking has already stopped.')
n_join_attempts = 0
while n_join_attempts < self.n_join_attempts:
self._stop_event.set()
Expand All @@ -248,6 +266,7 @@ def __exit__(self, *_):
if self.kill_if_join_fails:
log.warning('The thread failed to join and kill_if_join_fails is set. Exiting ...')
sys.exit(1)
self.state = Tracker.State.STOPPED

def start(self):
"""
Expand Down
25 changes: 25 additions & 0 deletions tests/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,28 @@ def test_validate_unit():
with pt.raises(ValueError) as error:
gput.Tracker(ram_unit='milibytes')
assert str(error.value) == '"milibytes" is not a valid memory unit. Valid values are bytes, gigabytes, kilobytes, megabytes, terabytes'


def test_state(mocker):
profile_mock = mocker.patch('gpu_tracker.tracker.Tracker._profile')
assert not profile_mock.called
mocker.patch('gpu_tracker.tracker.subp.check_output', side_effect=[b''])
tracker = gput.Tracker()
assert tracker.__repr__() == 'State: NEW'
with pt.raises(RuntimeError) as error:
tracker.stop()
assert str(error.value) == 'Cannot stop tracking when tracking has not started yet.'
tracker.start()
assert tracker.__repr__() == 'State: STARTED'
with pt.raises(RuntimeError) as error:
tracker.start()
assert str(error.value) == 'Cannot start tracking when tracking has already started.'
tracker.stop()
assert tracker.__repr__() == 'State: STOPPED'
with pt.raises(RuntimeError) as error:
tracker.start()
assert str(error.value) == 'Cannot start tracking when tracking has already stopped.'
with pt.raises(RuntimeError) as error:
tracker.stop()
assert str(error.value) == 'Cannot stop tracking when tracking has already stopped.'
assert profile_mock.called

0 comments on commit 0926894

Please sign in to comment.