Timer function wrapper and instant timer #90
Open

eatpk wants to merge 10 commits into main from timer.
Commits (all by eatpk):

- 575aedb: Timer function wrapper and instant timer
- f6e2981: Timer function wrapper and instant timer
- 45ea001: comments
- 60c3696: fix
- 81ce797: requirements
- b1d86d3: fix
- a7db7a9: profiler
- eca583d: profiler
- ff1ad6f: profiler
- 8326b59: black test
Files changed:

@@ -0,0 +1 @@
from .timer import FunctionTimer, Timer
@@ -0,0 +1,186 @@
import functools
import logging
import os
import time

import psutil
import torch


def get_gpu_memory(device_index=None):
    return torch.cuda.memory_allocated(device_index)


def get_gpu_max_memory(device_index=None):
    return torch.cuda.max_memory_allocated(device_index)


def get_host_memory():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss


def get_cpu_swap_memory():
    return psutil.swap_memory().used


class FunctionTimer:
    """Decorator-based timer that accumulates per-label time and memory deltas."""

    log = {}

    @classmethod
    def _wrap_function(cls, func, label, host_timer):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if host_timer:
                return cls._host_timer_wrapper(func, label, *args, **kwargs)
            return cls._device_timer_wrapper(func, label, *args, **kwargs)

        return wrapper

    @classmethod
    def _host_timer_wrapper(cls, func, label, *args, **kwargs):
        before_memory = get_host_memory()
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        after_memory = get_host_memory()
        cls.log.setdefault(label, []).append(
            {
                "time_delta": end_time - start_time,
                # Memory growth in MiB (after minus before).
                "memory_delta": (after_memory - before_memory) >> 20,
            }
        )
        return result

    @classmethod
    def _device_timer_wrapper(cls, func, label, *args, **kwargs):
        before_memory = get_gpu_memory()
        start_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        result = func(*args, **kwargs)
        end_event = torch.cuda.Event(enable_timing=True)
        end_event.record()
        after_memory = get_gpu_memory()
        # elapsed_time() is only valid once both events have completed.
        torch.cuda.current_stream().wait_event(end_event)
        torch.cuda.synchronize()
        cls.log.setdefault(label, []).append(
            {
                "time_delta": start_event.elapsed_time(end_event)
                / 1000,  # milliseconds -> seconds
                # Memory growth in MiB (after minus before).
                "memory_delta": (after_memory - before_memory) >> 20,
            }
        )
        return result

    @classmethod
    def timer(cls, label_or_func=None):
        host_timer = getattr(
            cls, "host_timer", False
        )  # Fall back to False if not defined.

        def decorator(func):
            label = label_or_func if isinstance(label_or_func, str) else func.__name__
            return cls._wrap_function(func, label, host_timer)

        # Supports both bare @timer and parameterized @timer("label") usage.
        if callable(label_or_func):
            return decorator(label_or_func)
        return decorator

    @classmethod
    def get_log(cls):
        return cls.log

    @classmethod
    def print_log(cls):
        print("Function Timer Logs:")
        for label, details in cls.log.items():
            print(f"  {label}:")
            sum_time = sum(entry["time_delta"] for entry in details)
            print(f"    operation costs {sum_time} seconds")


class HostFunctionTimer(FunctionTimer):
    host_timer = True


class DeviceFunctionTimer(FunctionTimer):
    if torch.cuda.is_available():
        host_timer = False
    else:
        logging.warning("CUDA is not available; falling back to the host timer.")
        host_timer = True


class Timer:
    """Instant timer: start/stop named CPU or GPU timers and read elapsed times."""

    def __init__(self):
        self.timers = {
            "cpu": {},
            "gpu": {},
        }
        self.timer_info = {}  # Populated after synchronization.
        self.is_synchronized = False

    def start_timer(self, name, host_timer=False):
        if host_timer:
            if name in self.timers["cpu"]:
                logging.warning(f"Timer for {name} already exists.")
                return
            self.timers["cpu"][name] = [time.time()]
        else:
            if name in self.timers["gpu"]:
                logging.warning(f"Timer for {name} already exists.")
                return
            self.is_synchronized = False
            start_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            self.timers["gpu"][name] = [start_event]

    def stop_timer(self, name):
        if name in self.timers["cpu"]:
            self.timers["cpu"][name].append(time.time())
        if name in self.timers["gpu"]:
            self.is_synchronized = False
            end_event = torch.cuda.Event(enable_timing=True)
            end_event.record()
            self.timers["gpu"][name].append(end_event)

    def _calculate_elapse_time(self):
        # Elapsed times are reported in milliseconds for both CPU and GPU timers.
        for name, timer in self.timers["cpu"].items():
            assert len(timer) == 2, f"Timer for {name} was never stopped."
            self.timer_info[name] = (timer[1] - timer[0]) * 1000
        if not self.is_synchronized:
            for name, events in self.timers["gpu"].items():
                assert len(events) == 2, f"Timer for {name} was never stopped."
                torch.cuda.current_stream().wait_event(events[1])
                torch.cuda.synchronize()
                self.timer_info[name] = events[0].elapsed_time(events[1])
            self.is_synchronized = True

    def get_info(self):
        if not self.is_synchronized:
            self._calculate_elapse_time()
        return self.timer_info
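
For context, here is a minimal usage sketch of the two APIs this PR adds. The import path and the matmul workload are assumptions for illustration, and HostFunctionTimer is used so the example also runs without CUDA:

import torch

from timer import HostFunctionTimer, Timer  # import path assumed


@HostFunctionTimer.timer("matmul")
def matmul(a, b):
    return a @ b


a = torch.randn(512, 512)
matmul(a, a)
matmul(a, a)

# Aggregates the time of both "matmul" calls under one label.
HostFunctionTimer.print_log()

# Instant timer: start/stop around an arbitrary code block.
timer = Timer()
timer.start_timer("block", host_timer=True)
matmul(a, a)
timer.stop_timer("block")
print(timer.get_info())  # {"block": <elapsed milliseconds>}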
@@ -2,6 +2,5 @@ numpy
pandas
torch
einops
psutil
pyyaml
Review comments:

I can just remove psutil, since it is only used for the memory monitoring, and the memory monitoring is not really doing anything at this point (it just computes the memory difference before and after the call). We may need to do this via profiling, as mentioned before.

I believe we can remove psutil for now. We can add it back once there is a request for CPU memory tracking from the users.
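
As a hedged sketch of the profiling approach the reviewers mention: torch.profiler can attribute memory to individual operators, which is more informative than a single before/after delta. The train_step workload below is hypothetical, not part of this PR:

import torch
from torch.profiler import profile, ProfilerActivity


# Hypothetical workload; any callable that allocates tensors works here.
def train_step():
    x = torch.randn(1024, 1024)
    return x @ x


# profile_memory=True records allocations, so memory usage can be inspected
# per operator instead of as a single before/after difference.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, profile_memory=True) as prof:
    train_step()

print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))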