Timer function wrapper and instant timer #90

Open · wants to merge 10 commits into base: main

3 changes: 3 additions & 0 deletions analog/analog.py
@@ -17,6 +17,7 @@
from analog.lora import LoRAHandler
from analog.lora.utils import is_lora
from analog.state import AnaLogState
from analog.monitor_util.timer import DeviceFunctionTimer
from analog.utils import (
get_logger,
get_rank,
@@ -237,6 +238,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
self.logger.update()

@DeviceFunctionTimer.timer
def build_log_dataset(self):
"""
Constructs the log dataset from the stored logs. This dataset can then be used
@@ -249,6 +251,7 @@ def build_log_dataset(self):
log_dataset = LogDataset(log_dir=self.log_dir, config=self.influence_config)
return log_dataset

@DeviceFunctionTimer.timer
def build_log_dataloader(
self, batch_size: int = 16, num_workers: int = 0, pin_memory: bool = False
):
3 changes: 3 additions & 0 deletions analog/analysis/influence_function.py
@@ -5,6 +5,7 @@
from einops import einsum, rearrange, reduce
from analog.config import InfluenceConfig
from analog.state import AnaLogState
from analog.monitor_util.timer import DeviceFunctionTimer
from analog.utils import get_logger, nested_dict
from analog.analysis.utils import synchronize_device

@@ -24,6 +25,7 @@ def __init__(self, config: InfluenceConfig, state: AnaLogState):
self.influence_scores = pd.DataFrame()
self.flatten = config.flatten

@DeviceFunctionTimer.timer
@torch.no_grad()
def precondition(
self,
@@ -212,6 +214,7 @@ def flatten_log(self, src):
to_cat.append(log.view(bsz, -1))
return torch.cat(to_cat, dim=1)

@DeviceFunctionTimer.timer
def compute_influence_all(
self,
src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
3 changes: 3 additions & 0 deletions analog/logging/log_saver.py
@@ -1,6 +1,7 @@
from concurrent.futures import ThreadPoolExecutor
import torch

from analog.monitor_util.timer import DeviceFunctionTimer
from analog.utils import nested_dict, to_numpy, get_rank
from analog.logging.mmap import MemoryMapHandler

@@ -21,6 +22,7 @@ def __init__(self, config, state):
self.buffer = nested_dict()
self.buffer_size = 0

@DeviceFunctionTimer.timer
def buffer_write(self, binfo):
"""
Add log state on exit.
@@ -85,6 +87,7 @@ def _flush_serialized(self, log_dir) -> str:
del buffer_list
return log_dir

@DeviceFunctionTimer.timer
def flush(self) -> None:
"""
For the DefaultHandler, there's no batch operation needed since each add operation writes to the file.
7 changes: 7 additions & 0 deletions analog/logging/logger.py
@@ -10,6 +10,7 @@
from analog.logging.option import LogOption
from analog.logging.log_saver import LogSaver
from analog.logging.utils import compute_per_sample_gradient
from analog.monitor_util.timer import DeviceFunctionTimer
from analog.utils import get_logger


@@ -42,6 +43,7 @@ def __init__(
self.grad_hooks = []
self.tensor_hooks = []

@DeviceFunctionTimer.timer
def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):
"""
Add log state on exit.
@@ -59,6 +61,7 @@ def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):

return log

@DeviceFunctionTimer.timer
def update(self):
# Update statistics
for stat in self.opt.statistic["grad"]:
@@ -82,6 +85,7 @@ def update(self):
self.log_saver.buffer_write(binfo=self.binfo)
self.log_saver.flush()

@DeviceFunctionTimer.timer
def _forward_hook_fn(
self, module: nn.Module, inputs: Tuple[torch.Tensor], module_name: str
) -> None:
@@ -131,6 +135,7 @@ def _forward_hook_fn(
cpu_offload=self.cpu_offload,
)

@DeviceFunctionTimer.timer
def _backward_hook_fn(
self,
module: nn.Module,
@@ -172,6 +177,7 @@ def _backward_hook_fn(
cpu_offload=self.cpu_offload,
)

@DeviceFunctionTimer.timer
def _grad_hook_fn(
self,
module: nn.Module,
@@ -270,6 +276,7 @@ def _tensor_backward_hook_fn(self, grad: torch.Tensor, tensor_name: str) -> None
cpu_offload=self.cpu_offload,
)

@DeviceFunctionTimer.timer
def register_all_module_hooks(self) -> None:
"""
Register all module hooks.
2 changes: 2 additions & 0 deletions analog/monitor_util/__init__.py
@@ -0,0 +1,2 @@
from .timer import FunctionTimer, Timer
from .profiler import memory_profiler
28 changes: 28 additions & 0 deletions analog/monitor_util/profiler.py
@@ -0,0 +1,28 @@
import torch
import functools
from torch.profiler import profile, ProfilerActivity


def memory_profiler(func):
Collaborator Author commented: Adding this to the example code will be done in the next PR.

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        activities = [ProfilerActivity.CPU]
        if device.type == "cuda":
            activities.append(ProfilerActivity.CUDA)

        with profile(activities=activities, profile_memory=True) as prof:
            result = func(*args, **kwargs)

        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_memory_usage"
                )
            )
        )
        return result

    return wrapper
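
For reference, a minimal usage sketch of the memory_profiler decorator above, applied to a small hypothetical workload (the function name and tensor sizes are illustrative only, not part of this PR):

import torch

from analog.monitor_util import memory_profiler


@memory_profiler
def run_forward(batch_size: int = 64):
    # Hypothetical workload: a single linear-layer forward pass.
    layer = torch.nn.Linear(512, 512)
    x = torch.randn(batch_size, 512)
    return layer(x)


run_forward()  # prints the profiler table, sorted by CPU or CUDA memory usage
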
178 changes: 178 additions & 0 deletions analog/monitor_util/timer.py
@@ -0,0 +1,178 @@
import logging
import time
import functools

import torch


def get_gpu_memory(device_index=None):
    return torch.cuda.memory_allocated(device_index)


def get_gpu_max_memory(device_index=None):
    return torch.cuda.max_memory_allocated(device_index)


class FunctionTimer:
    log = {}

    @classmethod
    def _wrap_function(cls, func, label, host_timer):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if host_timer:
                return cls._host_timer_wrapper(func, label, *args, **kwargs)
            else:
                return cls._device_timer_wrapper(func, label, *args, **kwargs)

        return wrapper

    @classmethod
    def _host_timer_wrapper(cls, func, label, *args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        if label not in cls.log:
            cls.log[label] = [
                {
                    "time_delta": end_time - start_time,
                }
            ]
        else:
            cls.log[label].append(
                {
                    "time_delta": end_time - start_time,
                }
            )
        return result

    @classmethod
    def _device_timer_wrapper(cls, func, label, *args, **kwargs):
        start_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        result = func(*args, **kwargs)
        end_event = torch.cuda.Event(enable_timing=True)
        end_event.record()
        torch.cuda.current_stream().wait_event(end_event)
        torch.cuda.synchronize()
        if label not in cls.log:
            cls.log[label] = [
                {
                    "time_delta": start_event.elapsed_time(end_event)
                    / 1000,  # turn to seconds
                }
            ]
        else:
            cls.log[label].append(
                {
                    "time_delta": start_event.elapsed_time(end_event)
                    / 1000,  # turn to seconds
                }
            )
        return result

    @classmethod
    def timer(cls, label_or_func=None):
        host_timer = getattr(
            cls, "host_timer", False
        )  # Fallback to False if not defined

        def decorator(func):
            label = label_or_func if isinstance(label_or_func, str) else func.__name__
            return cls._wrap_function(func, label, host_timer)

        if callable(label_or_func):
            return decorator(label_or_func)
        return decorator

    @classmethod
    def get_log(cls):
        return cls.log

    @classmethod
    def print_log(cls):
        print(
            "###########################################################################"
        )
        print(
            "################################ TIMER LOG ################################"
        )
        header = f"{'Label':<50} | {'Total Time (sec)':>20}"
        print(header)
        print("-" * len(header))
        for label, details in cls.log.items():
            sum_time = 0
            for log_entry in details:
                time_delta = log_entry.get("time_delta", 0)
                sum_time += time_delta
            # Truncate to 47 characters and append "..." if the label is longer than 50.
            display_label = (label[:47] + "...") if len(label) > 50 else label
            row = f"{display_label:<50} | {sum_time:>20.4f}"
            print(row)


class HostFunctionTimer(FunctionTimer):
    host_timer = True


class DeviceFunctionTimer(FunctionTimer):
    if torch.cuda.is_available():
        host_timer = False
    else:
        logging.warning("CUDA is not available; falling back to the host timer.")
        host_timer = True


class Timer:
    def __init__(self):
        self.timers = {
            "cpu": {},
            "gpu": {},
        }
        self.timer_info = {}  # synchronized.
        self.is_synchronized = False

    def start_timer(self, name, host_timer=False):
        if host_timer:
            if name in self.timers["cpu"]:
                logging.warning(f"Timer for {name} already exists")
                return
            start_time = time.time()
            self.timers["cpu"][name] = [start_time]
        else:
            if name in self.timers["gpu"]:
                logging.warning(f"Timer for {name} already exists")
                return
            self.is_synchronized = False
            start_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            self.timers["gpu"][name] = [start_event]

    def stop_timer(self, name):
        if name in self.timers["cpu"]:
            end_time = time.time()
            self.timers["cpu"][name].append(end_time)
        if name in self.timers["gpu"]:
            self.is_synchronized = False
            end_event = torch.cuda.Event(enable_timing=True)
            end_event.record()
            self.timers["gpu"][name].append(end_event)

    def _calculate_elapse_time(self):
        for name, timer in self.timers["cpu"].items():
            assert len(timer) == 2
            self.timer_info[name] = (timer[1] - timer[0]) * 1000
        if not self.is_synchronized:
            for name, events in self.timers["gpu"].items():
                assert len(events) == 2
                torch.cuda.current_stream().wait_event(events[1])
                torch.cuda.synchronize()
                self.timer_info[name] = events[0].elapsed_time(events[1])
            self.is_synchronized = True

    def get_info(self):
        if not self.is_synchronized:
            self._calculate_elapse_time()
        return self.timer_info
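
For reference, a minimal usage sketch of the two APIs added in this file: the function-wrapper style (a FunctionTimer subclass used as a decorator, with or without an explicit label) and the instant Timer with explicit start/stop calls. The workload functions and sizes below are hypothetical:

import torch

from analog.monitor_util import Timer
from analog.monitor_util.timer import DeviceFunctionTimer


# Decorator style: each call is accumulated in the shared class-level log,
# keyed by the function name or by an explicit string label.
@DeviceFunctionTimer.timer
def matmul_once(n: int = 1024):
    return torch.randn(n, n) @ torch.randn(n, n)


@DeviceFunctionTimer.timer("labelled_matmul")
def matmul_again(n: int = 512):
    return torch.randn(n, n) @ torch.randn(n, n)


# Instant-timer style: explicit start/stop around an arbitrary code region.
timer = Timer()
timer.start_timer("data_prep", host_timer=True)
x = torch.randn(4096, 4096)
timer.stop_timer("data_prep")

matmul_once()
matmul_again()

DeviceFunctionTimer.print_log()  # per-label total times, in seconds
print(timer.get_info())          # {"data_prep": elapsed time in milliseconds}

On a machine without CUDA, DeviceFunctionTimer falls back to the host (wall-clock) timer, so the sketch above runs on CPU as well.
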
3 changes: 3 additions & 0 deletions examples/mnist_influence/compute_influences.py
@@ -11,6 +11,8 @@
construct_mlp,
)

from analog.monitor_util import FunctionTimer

parser = argparse.ArgumentParser("MNIST Influence Analysis")
parser.add_argument("--data", type=str, default="mnist", help="mnist or fmnist")
parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
@@ -79,6 +81,7 @@
)
_, top_influential_data = torch.topk(if_scores, k=10)

FunctionTimer.print_log()
# Save
if_scores = if_scores.cpu().numpy().tolist()[0]
torch.save(if_scores, "if_analog.pt")
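
Beyond the printed table, the accumulated timings can also be read back programmatically via FunctionTimer.get_log(), which returns a dict mapping each label to a list of {"time_delta": seconds} entries. A small sketch (not part of this PR):

from analog.monitor_util import FunctionTimer

# Aggregate total time and call counts for each timed label.
for label, entries in FunctionTimer.get_log().items():
    total = sum(entry["time_delta"] for entry in entries)
    print(f"{label}: {total:.4f}s over {len(entries)} call(s)")
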
1 change: 0 additions & 1 deletion requirements.txt
@@ -4,4 +4,3 @@ torch
einops

pyyaml
