Commit 16f35cc: profiler

eatpk committed Feb 25, 2024
1 parent 2c00270
Showing 3 changed files with 206 additions and 0 deletions.
2 changes: 2 additions & 0 deletions analog/monitor_util/__init__.py
@@ -0,0 +1,2 @@
from .timer import FunctionTimer, Timer
from .profiler import memory_profiler
26 changes: 26 additions & 0 deletions analog/monitor_util/profiler.py
@@ -0,0 +1,26 @@
import functools

import torch
from torch.profiler import profile, ProfilerActivity


def memory_profiler(func):
    """Print a torch.profiler memory-usage summary each time `func` is called."""

    @functools.wraps(func)
def wrapper(*args, **kwargs):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
activities = [ProfilerActivity.CPU]
if device.type == "cuda":
activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, profile_memory=True) as prof:
result = func(*args, **kwargs)

        # Sort the summary table by memory usage on the profiled device.
        print(
prof.key_averages().table(
sort_by="self_cuda_memory_usage"
if device.type == "cuda"
else "self_cpu_memory_usage"
)
)
return result

return wrapper
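
A minimal usage sketch for the decorator above (the function name and tensor sizes here are hypothetical, for illustration only):

import torch

from analog.monitor_util import memory_profiler

@memory_profiler
def matmul_sum(x):
    # Any workload; it runs under torch.profiler with profile_memory=True.
    return (x @ x.T).sum()

result = matmul_sum(torch.randn(256, 256))
# Prints a key_averages() table sorted by self_cpu_memory_usage
# (or self_cuda_memory_usage when CUDA is available).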
178 changes: 178 additions & 0 deletions analog/monitor_util/timer.py
@@ -0,0 +1,178 @@
import logging
import time
import functools

import torch


def get_gpu_memory(device_index=None):
    # Current GPU memory occupied by tensors, in bytes.
    return torch.cuda.memory_allocated(device_index)


def get_gpu_max_memory(device_index=None):
    # Peak GPU memory occupied by tensors, in bytes.
    return torch.cuda.max_memory_allocated(device_index)


class FunctionTimer:
    log = {}  # label -> list of {"time_delta": seconds}; shared across subclasses

@classmethod
def _wrap_function(cls, func, label, host_timer):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if host_timer:
return cls._host_timer_wrapper(func, label, *args, **kwargs)
else:
return cls._device_timer_wrapper(func, label, *args, **kwargs)

return wrapper

@classmethod
    def _host_timer_wrapper(cls, func, label, *args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        cls.log.setdefault(label, []).append(
            {"time_delta": end_time - start_time}  # seconds
        )
        return result

@classmethod
    def _device_timer_wrapper(cls, func, label, *args, **kwargs):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        result = func(*args, **kwargs)
        end_event.record()
        torch.cuda.synchronize()  # ensure both events have completed before reading
        cls.log.setdefault(label, []).append(
            {"time_delta": start_event.elapsed_time(end_event) / 1000}  # ms -> s
        )
        return result

@classmethod
def timer(cls, label_or_func=None):
host_timer = getattr(
cls, "host_timer", False
) # Fallback to False if not defined

def decorator(func):
label = label_or_func if isinstance(label_or_func, str) else func.__name__
return cls._wrap_function(func, label, host_timer)

if callable(label_or_func):
return decorator(label_or_func)
return decorator

@classmethod
def get_log(cls):
return cls.log

@classmethod
def print_log(cls):
print(
"###########################################################################"
)
print(
"################################ TIMER LOG ################################"
)
header = f"{'Label':<50} | {'Total Time (sec)':>20}"
print(header)
print("-" * len(header))
        for label, details in cls.log.items():
            total_time = sum(entry.get("time_delta", 0) for entry in details)
            # Truncate to 47 characters and append "..." if the label exceeds 50.
            display_label = (label[:47] + "...") if len(label) > 50 else label
            print(f"{display_label:<50} | {total_time:>20.4f}")


class HostFunctionTimer(FunctionTimer):
host_timer = True


class DeviceFunctionTimer(FunctionTimer):
    if torch.cuda.is_available():
        host_timer = False
    else:
        logging.warning(
            "CUDA is not available; falling back to the host timer."
        )
        host_timer = True


class Timer:
def __init__(self):
self.timers = {
"cpu": {},
"gpu": {},
}
        self.timer_info = {}  # name -> elapsed milliseconds, filled in after synchronization
self.is_synchronized = False

def start_timer(self, name, host_timer=False):
if host_timer:
if name in self.timers["cpu"]:
logging.warning(f"monitor_util for {name} already exist")
return
start_time = time.time()
self.timers["cpu"][name] = [start_time]
else:
if name in self.timers["gpu"]:
logging.warning(f"monitor_util for {name} already exist")
return
self.is_synchronized = False
start_event = torch.cuda.Event(enable_timing=True)
start_event.record()
self.timers["gpu"][name] = [start_event]

def stop_timer(self, name):
if name in self.timers["cpu"]:
end_time = time.time()
self.timers["cpu"][name].append(end_time)
if name in self.timers["gpu"]:
self.is_synchronized = False
end_event = torch.cuda.Event(enable_timing=True)
end_event.record()
self.timers["gpu"][name].append(end_event)

    def _calculate_elapsed_time(self):
        for name, timer in self.timers["cpu"].items():
            assert len(timer) == 2, f"Timer for {name} was never stopped."
            self.timer_info[name] = (timer[1] - timer[0]) * 1000  # seconds -> ms
        if not self.is_synchronized:
            for name, events in self.timers["gpu"].items():
                assert len(events) == 2, f"Timer for {name} was never stopped."
                torch.cuda.current_stream().wait_event(events[1])
                torch.cuda.synchronize()
                self.timer_info[name] = events[0].elapsed_time(events[1])  # ms
            self.is_synchronized = True

def get_info(self):
if not self.is_synchronized:
            self._calculate_elapsed_time()
return self.timer_info
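
A usage sketch for the two timing utilities above (function names and sleep durations are hypothetical). Note the unit asymmetry: FunctionTimer logs seconds, while Timer.get_info() reports milliseconds.

import time

from analog.monitor_util.timer import DeviceFunctionTimer, Timer

@DeviceFunctionTimer.timer  # label defaults to the function name
def slow_step():
    time.sleep(0.5)

@DeviceFunctionTimer.timer("labeled_step")  # explicit label
def quick_step():
    time.sleep(0.1)

slow_step()
quick_step()
DeviceFunctionTimer.print_log()  # per-label totals, in seconds

timer = Timer()
timer.start_timer("block", host_timer=True)
time.sleep(0.2)
timer.stop_timer("block")
print(timer.get_info())  # {"block": ~200.0}, in milliseconds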
