Timer function wrapper and instant timer #90

Open: wants to merge 10 commits into main (showing changes from 5 commits)
3 changes: 3 additions & 0 deletions analog/analog.py
@@ -17,6 +17,7 @@
from analog.lora import LoRAHandler
from analog.lora.utils import is_lora
from analog.state import AnaLogState
from analog.timer.timer import DeviceFunctionTimer
from analog.utils import (
    get_logger,
    get_rank,
@@ -237,6 +238,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
        """
        self.logger.update()

    @DeviceFunctionTimer.timer
    def build_log_dataset(self):
        """
        Constructs the log dataset from the stored logs. This dataset can then be used
@@ -249,6 +251,7 @@ def build_log_dataset(self):
        log_dataset = LogDataset(log_dir=self.log_dir, config=self.influence_config)
        return log_dataset

    @DeviceFunctionTimer.timer
    def build_log_dataloader(
        self, batch_size: int = 16, num_workers: int = 0, pin_memory: bool = False
    ):
3 changes: 3 additions & 0 deletions analog/analysis/influence_function.py
@@ -5,6 +5,7 @@
from einops import einsum, rearrange, reduce
from analog.config import InfluenceConfig
from analog.state import AnaLogState
from analog.timer.timer import DeviceFunctionTimer
from analog.utils import get_logger, nested_dict
from analog.analysis.utils import synchronize_device

@@ -24,6 +25,7 @@ def __init__(self, config: InfluenceConfig, state: AnaLogState):
        self.influence_scores = pd.DataFrame()
        self.flatten = config.flatten

    @DeviceFunctionTimer.timer
    @torch.no_grad()
    def precondition(
        self,
@@ -212,6 +214,7 @@ def flatten_log(self, src):
            to_cat.append(log.view(bsz, -1))
        return torch.cat(to_cat, dim=1)

    @DeviceFunctionTimer.timer
    def compute_influence_all(
        self,
        src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
3 changes: 3 additions & 0 deletions analog/logging/log_saver.py
@@ -1,6 +1,7 @@
from concurrent.futures import ThreadPoolExecutor
import torch

from analog.timer.timer import HostFunctionTimer, DeviceFunctionTimer
from analog.utils import nested_dict, to_numpy, get_rank
from analog.logging.mmap import MemoryMapHandler

@@ -21,6 +22,7 @@ def __init__(self, config, state):
        self.buffer = nested_dict()
        self.buffer_size = 0

    @DeviceFunctionTimer.timer
    def buffer_write(self, binfo):
        """
        Add log state on exit.
@@ -85,6 +87,7 @@ def _flush_serialized(self, log_dir) -> str:
        del buffer_list
        return log_dir

    @DeviceFunctionTimer.timer
    def flush(self) -> None:
        """
        For the DefaultHandler, there's no batch operation needed since each add operation writes to the file.
7 changes: 7 additions & 0 deletions analog/logging/logger.py
@@ -10,6 +10,7 @@
from analog.logging.option import LogOption
from analog.logging.log_saver import LogSaver
from analog.logging.utils import compute_per_sample_gradient
from analog.timer.timer import DeviceFunctionTimer
from analog.utils import get_logger


@@ -42,6 +43,7 @@ def __init__(
        self.grad_hooks = []
        self.tensor_hooks = []

    @DeviceFunctionTimer.timer
    def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):
        """
        Add log state on exit.
@@ -59,6 +61,7 @@ def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):

        return log

    @DeviceFunctionTimer.timer
    def update(self):
        # Update statistics
        for stat in self.opt.statistic["grad"]:
@@ -82,6 +85,7 @@ def update(self):
        self.log_saver.buffer_write(binfo=self.binfo)
        self.log_saver.flush()

    @DeviceFunctionTimer.timer
    def _forward_hook_fn(
        self, module: nn.Module, inputs: Tuple[torch.Tensor], module_name: str
    ) -> None:
@@ -131,6 +135,7 @@ def _forward_hook_fn(
            cpu_offload=self.cpu_offload,
        )

    @DeviceFunctionTimer.timer
    def _backward_hook_fn(
        self,
        module: nn.Module,
@@ -172,6 +177,7 @@ def _backward_hook_fn(
            cpu_offload=self.cpu_offload,
        )

    @DeviceFunctionTimer.timer
    def _grad_hook_fn(
        self,
        module: nn.Module,
@@ -270,6 +276,7 @@ def _tensor_backward_hook_fn(self, grad: torch.Tensor, tensor_name: str) -> None:
            cpu_offload=self.cpu_offload,
        )

    @DeviceFunctionTimer.timer
    def register_all_module_hooks(self) -> None:
        """
        Register all module hooks.
1 change: 1 addition & 0 deletions analog/timer/__init__.py
@@ -0,0 +1 @@
from .timer import FunctionTimer, Timer
186 changes: 186 additions & 0 deletions analog/timer/timer.py
@@ -0,0 +1,186 @@
import logging
import time
import functools

import torch
import psutil
import os


def get_gpu_memory(device_index=None):
    return torch.cuda.memory_allocated(device_index)


def get_gpu_max_memory(device_index=None):
    return torch.cuda.max_memory_allocated(device_index)


def get_host_memory():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss


def get_cpu_swap_memory():
    return psutil.swap_memory().used


class FunctionTimer:
    log = {}

    @classmethod
    def _wrap_function(cls, func, label, host_timer):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if host_timer:
                return cls._host_timer_wrapper(func, label, *args, **kwargs)
            else:
                return cls._device_timer_wrapper(func, label, *args, **kwargs)

        return wrapper

    @classmethod
    def _host_timer_wrapper(cls, func, label, *args, **kwargs):
        before_memory = get_host_memory()
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        after_memory = get_host_memory()
        if label not in cls.log:
            cls.log[label] = [
                {
                    "time_delta": end_time - start_time,
                    "memory_delta": (after_memory - before_memory) >> 20,  # bytes -> MiB
                }
            ]
        else:
            cls.log[label].append(
                {
                    "time_delta": end_time - start_time,
                    "memory_delta": (after_memory - before_memory) >> 20,  # bytes -> MiB
                }
            )
        return result

    @classmethod
    def _device_timer_wrapper(cls, func, label, *args, **kwargs):
        before_memory = get_gpu_memory()
        start_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        result = func(*args, **kwargs)
        end_event = torch.cuda.Event(enable_timing=True)
        end_event.record()
        after_memory = get_gpu_memory()
        torch.cuda.current_stream().wait_event(end_event)
        torch.cuda.synchronize()
        if label not in cls.log:
            cls.log[label] = [
                {
                    "time_delta": start_event.elapsed_time(end_event)
                    / 1000,  # ms -> seconds
                    "memory_delta": (after_memory - before_memory) >> 20,  # bytes -> MiB
                }
            ]
        else:
            cls.log[label].append(
                {
                    "time_delta": start_event.elapsed_time(end_event)
                    / 1000,  # ms -> seconds
                    "memory_delta": (after_memory - before_memory) >> 20,  # bytes -> MiB
                }
            )
        return result

    @classmethod
    def timer(cls, label_or_func=None):
        host_timer = getattr(
            cls, "host_timer", False
        )  # Fall back to False if not defined

        def decorator(func):
            label = label_or_func if isinstance(label_or_func, str) else func.__name__
            return cls._wrap_function(func, label, host_timer)

        if callable(label_or_func):
            return decorator(label_or_func)
        return decorator

    @classmethod
    def get_log(cls):
        return cls.log

    @classmethod
    def print_log(cls):
        print("Function Timer Logs:")
        for label, details in cls.log.items():
            print(f"  {label}:")
            sum_time = 0
            for log in details:
                for key, value in log.items():
                    if key == "time_delta":
                        sum_time += value
            print(f"    total: {sum_time} seconds")


class HostFunctionTimer(FunctionTimer):
    host_timer = True


class DeviceFunctionTimer(FunctionTimer):
    if torch.cuda.is_available():
        host_timer = False
    else:
        logging.warning("CUDA is not available; falling back to the host timer.")
        host_timer = True


class Timer:
    def __init__(self):
        self.timers = {
            "cpu": {},
            "gpu": {},
        }
        self.timer_info = {}  # populated once timers are synchronized.
        self.is_synchronized = False

    def start_timer(self, name, host_timer=False):
        if host_timer:
            if name in self.timers["cpu"]:
                logging.warning(f"timer for {name} already exists")
                return
            start_time = time.time()
            self.timers["cpu"][name] = [start_time]
        else:
            if name in self.timers["gpu"]:
                logging.warning(f"timer for {name} already exists")
                return
            self.is_synchronized = False
            start_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            self.timers["gpu"][name] = [start_event]

    def stop_timer(self, name):
        if name in self.timers["cpu"]:
            end_time = time.time()
            self.timers["cpu"][name].append(end_time)
        if name in self.timers["gpu"]:
            self.is_synchronized = False
            end_event = torch.cuda.Event(enable_timing=True)
            end_event.record()
            self.timers["gpu"][name].append(end_event)

    def _calculate_elapse_time(self):
        for name, timer in self.timers["cpu"].items():
            assert len(timer) == 2
            self.timer_info[name] = (timer[1] - timer[0]) * 1000  # seconds -> ms
        if not self.is_synchronized:
            for name, events in self.timers["gpu"].items():
                assert len(events) == 2
                torch.cuda.current_stream().wait_event(events[1])
                torch.cuda.synchronize()
                self.timer_info[name] = events[0].elapsed_time(events[1])  # ms
            self.is_synchronized = True

    def get_info(self):
        if not self.is_synchronized:
            self._calculate_elapse_time()
        return self.timer_info
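
For reference, a minimal usage sketch based only on the API added in this file; the timed function, tensor sizes, and region name below are made up for illustration, and FunctionTimer.print_log() works here because the log dictionary is shared with the subclasses:

import torch

from analog.timer.timer import DeviceFunctionTimer, FunctionTimer, Timer


@DeviceFunctionTimer.timer  # or @DeviceFunctionTimer.timer("custom_label")
def matmul_step(a, b):
    return a @ b


x = torch.randn(512, 512)
matmul_step(x, x)
FunctionTimer.print_log()  # cumulative time per label, as in compute_influences.py

# "Instant" timer: explicit start/stop around an arbitrary region.
timer = Timer()
use_host = not torch.cuda.is_available()
timer.start_timer("matmul_region", host_timer=use_host)
_ = x @ x
timer.stop_timer("matmul_region")
print(timer.get_info())  # {"matmul_region": elapsed milliseconds}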
3 changes: 3 additions & 0 deletions examples/mnist_influence/compute_influences.py
@@ -11,6 +11,8 @@
    construct_mlp,
)

from analog.timer import FunctionTimer

parser = argparse.ArgumentParser("MNIST Influence Analysis")
parser.add_argument("--data", type=str, default="mnist", help="mnist or fmnist")
parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
@@ -79,6 +81,7 @@
)
_, top_influential_data = torch.topk(if_scores, k=10)

FunctionTimer.print_log()
# Save
if_scores = if_scores.cpu().numpy().tolist()[0]
torch.save(if_scores, "if_analog.pt")
3 changes: 1 addition & 2 deletions requirements.txt
@@ -2,6 +2,5 @@ numpy
pandas
torch
einops

psutil
@eatpk (Collaborator, Author) commented on Feb 19, 2024:
I can just remove psutil, since it is only used for memory monitoring, and the memory monitoring is not really doing anything useful at this point (it just computes the memory difference before and after a call). We may need to do this via profiling, as mentioned before.
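
For what it's worth, a rough sketch of what profiler-based tracking could look like, using torch.profiler instead of the before/after memory diff (the wrapped-function helper and the sort key chosen here are just illustrative):

import torch
from torch.profiler import ProfilerActivity, profile


def profiled_call(fn, *args, **kwargs):
    # Collect per-operator time and memory instead of a single before/after diff.
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities, profile_memory=True) as prof:
        result = fn(*args, **kwargs)
    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
    return result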

A collaborator replied:
I believe we can remove psutil for now. We can add it back once there is a request for CPU memory tracking from the users.
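
If CPU memory tracking does come back, one psutil-free option is the standard-library resource module; a rough, Unix-only sketch (note that ru_maxrss is reported in KiB on Linux but in bytes on macOS):

import resource
import sys


def peak_host_memory_mib():
    # Peak resident set size of the current process since it started.
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        return peak / (1024 * 1024)  # bytes -> MiB on macOS
    return peak / 1024  # KiB -> MiB on Linux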

pyyaml
