
Commit

fix
eatpk committed Feb 19, 2024
1 parent 45ea001 commit 60c3696
Showing 8 changed files with 206 additions and 135 deletions.
3 changes: 3 additions & 0 deletions analog/analog.py
@@ -17,6 +17,7 @@
from analog.lora import LoRAHandler
from analog.lora.utils import is_lora
from analog.state import AnaLogState
from analog.timer.timer import DeviceFunctionTimer
from analog.utils import (
    get_logger,
    get_rank,
@@ -237,6 +238,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
        """
        self.logger.update()

    @DeviceFunctionTimer.timer
    def build_log_dataset(self):
        """
        Constructs the log dataset from the stored logs. This dataset can then be used
@@ -249,6 +251,7 @@ def build_log_dataset(self):
        log_dataset = LogDataset(log_dir=self.log_dir, config=self.influence_config)
        return log_dataset

    @DeviceFunctionTimer.timer
    def build_log_dataloader(
        self, batch_size: int = 16, num_workers: int = 0, pin_memory: bool = False
    ):
3 changes: 3 additions & 0 deletions analog/analysis/influence_function.py
@@ -5,6 +5,7 @@
from einops import einsum, rearrange, reduce
from analog.config import InfluenceConfig
from analog.state import AnaLogState
from analog.timer.timer import DeviceFunctionTimer
from analog.utils import get_logger, nested_dict
from analog.analysis.utils import synchronize_device

@@ -24,6 +25,7 @@ def __init__(self, config: InfluenceConfig, state: AnaLogState):
        self.influence_scores = pd.DataFrame()
        self.flatten = config.flatten

    @DeviceFunctionTimer.timer
    @torch.no_grad()
    def precondition(
        self,
@@ -212,6 +214,7 @@ def flatten_log(self, src):
            to_cat.append(log.view(bsz, -1))
        return torch.cat(to_cat, dim=1)

    @DeviceFunctionTimer.timer
    def compute_influence_all(
        self,
        src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
3 changes: 3 additions & 0 deletions analog/logging/log_saver.py
@@ -1,6 +1,7 @@
from concurrent.futures import ThreadPoolExecutor
import torch

from analog.timer.timer import HostFunctionTimer, DeviceFunctionTimer
from analog.utils import nested_dict, to_numpy, get_rank
from analog.logging.mmap import MemoryMapHandler

@@ -21,6 +22,7 @@ def __init__(self, config, state):
        self.buffer = nested_dict()
        self.buffer_size = 0

    @DeviceFunctionTimer.timer
    def buffer_write(self, binfo):
        """
        Add log state on exit.
@@ -85,6 +87,7 @@ def _flush_serialized(self, log_dir) -> str:
        del buffer_list
        return log_dir

    @DeviceFunctionTimer.timer
    def flush(self) -> None:
        """
        For the DefaultHandler, there's no batch operation needed since each add operation writes to the file.
7 changes: 7 additions & 0 deletions analog/logging/logger.py
@@ -10,6 +10,7 @@
from analog.logging.option import LogOption
from analog.logging.log_saver import LogSaver
from analog.logging.utils import compute_per_sample_gradient
from analog.timer.timer import DeviceFunctionTimer
from analog.utils import get_logger


@@ -42,6 +43,7 @@ def __init__(
        self.grad_hooks = []
        self.tensor_hooks = []

    @DeviceFunctionTimer.timer
    def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):
        """
        Add log state on exit.
@@ -59,6 +61,7 @@ def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):

        return log

    @DeviceFunctionTimer.timer
    def update(self):
        # Update statistics
        for stat in self.opt.statistic["grad"]:
@@ -82,6 +85,7 @@ def update(self):
        self.log_saver.buffer_write(binfo=self.binfo)
        self.log_saver.flush()

    @DeviceFunctionTimer.timer
    def _forward_hook_fn(
        self, module: nn.Module, inputs: Tuple[torch.Tensor], module_name: str
    ) -> None:
@@ -131,6 +135,7 @@ def _forward_hook_fn(
            cpu_offload=self.cpu_offload,
        )

    @DeviceFunctionTimer.timer
    def _backward_hook_fn(
        self,
        module: nn.Module,
@@ -172,6 +177,7 @@ def _backward_hook_fn(
            cpu_offload=self.cpu_offload,
        )

    @DeviceFunctionTimer.timer
    def _grad_hook_fn(
        self,
        module: nn.Module,
@@ -270,6 +276,7 @@ def _tensor_backward_hook_fn(self, grad: torch.Tensor, tensor_name: str) -> None
            cpu_offload=self.cpu_offload,
        )

    @DeviceFunctionTimer.timer
    def register_all_module_hooks(self) -> None:
        """
        Register all module hooks.
1 change: 1 addition & 0 deletions analog/timer/__init__.py
@@ -0,0 +1 @@
from .timer import FunctionTimer, Timer
186 changes: 186 additions & 0 deletions analog/timer/timer.py
@@ -0,0 +1,186 @@
import logging
import time
import functools

import torch
import psutil
import os


def get_gpu_memory(device_index=None):
    return torch.cuda.memory_allocated(device_index)


def get_gpu_max_memory(device_index=None):
    return torch.cuda.max_memory_allocated(device_index)


def get_host_memory():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss


def get_cpu_swap_memory():
    return psutil.swap_memory().used


class FunctionTimer:
    log = {}

    @classmethod
    def _wrap_function(cls, func, label, host_timer):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if host_timer:
                return cls._host_timer_wrapper(func, label, *args, **kwargs)
            else:
                return cls._device_timer_wrapper(func, label, *args, **kwargs)

        return wrapper

    @classmethod
    def _host_timer_wrapper(cls, func, label, *args, **kwargs):
        before_memory = get_host_memory()
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        after_memory = get_host_memory()
        if label not in cls.log:
            cls.log[label] = [
                {
                    "time_delta": end_time - start_time,
                    "memory_delta": (before_memory - after_memory) >> 20,
                }
            ]
        else:
            cls.log[label].append(
                {
                    "time_delta": end_time - start_time,
                    "memory_delta": (before_memory - after_memory) >> 20,
                }
            )
        return result

    @classmethod
    def _device_timer_wrapper(cls, func, label, *args, **kwargs):
        before_memory = get_gpu_memory()
        start_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        result = func(*args, **kwargs)
        end_event = torch.cuda.Event(enable_timing=True)
        end_event.record()
        after_memory = get_gpu_memory()
        torch.cuda.current_stream().wait_event(end_event)
        torch.cuda.synchronize()
        if label not in cls.log:
            cls.log[label] = [
                {
                    "time_delta": start_event.elapsed_time(end_event)
                    / 1000,  # convert to seconds
                    "memory_delta": (before_memory - after_memory) >> 20,
                }
            ]
        else:
            cls.log[label].append(
                {
                    "time_delta": start_event.elapsed_time(end_event)
                    / 1000,  # convert to seconds
                    "memory_delta": (before_memory - after_memory) >> 20,
                }
            )
        return result

    @classmethod
    def timer(cls, label_or_func=None):
        host_timer = getattr(
            cls, "host_timer", False
        )  # Fallback to False if not defined

        def decorator(func):
            label = label_or_func if isinstance(label_or_func, str) else func.__name__
            return cls._wrap_function(func, label, host_timer)

        if callable(label_or_func):
            return decorator(label_or_func)
        return decorator

    @classmethod
    def get_log(cls):
        return cls.log

    @classmethod
    def print_log(cls):
        print("Function Timer Logs:")
        for label, details in cls.log.items():
            print(f"  {label}:")
            sum_time = 0
            for log in details:
                for key, value in log.items():
                    if key == "time_delta":
                        sum_time += value
            print(f"      operation costs {sum_time} seconds")


class HostFunctionTimer(FunctionTimer):
    host_timer = True


class DeviceFunctionTimer(FunctionTimer):
    if torch.cuda.is_available():
        host_timer = False
    else:
        logging.warning("CUDA is not available; falling back to the host timer.")
        host_timer = True


class Timer:
    def __init__(self):
        self.timers = {
            "cpu": {},
            "gpu": {},
        }
        self.timer_info = {}  # synchronized.
        self.is_synchronized = False

    def start_timer(self, name, host_timer=False):
        if host_timer:
            if name in self.timers["cpu"]:
                logging.warning(f"timer for {name} already exists")
                return
            start_time = time.time()
            self.timers["cpu"][name] = [start_time]
        else:
            if name in self.timers["gpu"]:
                logging.warning(f"timer for {name} already exists")
                return
            self.is_synchronized = False
            start_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            self.timers["gpu"][name] = [start_event]

    def stop_timer(self, name):
        if name in self.timers["cpu"]:
            end_time = time.time()
            self.timers["cpu"][name].append(end_time)
        if name in self.timers["gpu"]:
            self.is_synchronized = False
            end_event = torch.cuda.Event(enable_timing=True)
            end_event.record()
            self.timers["gpu"][name].append(end_event)

    def _calculate_elapse_time(self):
        for name, timer in self.timers["cpu"].items():
            assert len(timer) == 2
            self.timer_info[name] = (timer[1] - timer[0]) * 1000
        if not self.is_synchronized:
            for name, events in self.timers["gpu"].items():
                assert len(events) == 2
                torch.cuda.current_stream().wait_event(events[1])
                torch.cuda.synchronize()
                self.timer_info[name] = events[0].elapsed_time(events[1])
            self.is_synchronized = True

    def get_info(self):
        if not self.is_synchronized:
            self._calculate_elapse_time()
        return self.timer_info
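For reference, a minimal usage sketch of the timing utilities added in this file. Only FunctionTimer, HostFunctionTimer, Timer, and their methods come from the diff above; the cpu_matmul workload and the surrounding script are hypothetical:

import torch

from analog.timer import FunctionTimer, Timer
from analog.timer.timer import HostFunctionTimer


# Decorator usage: the label can be passed as a string, otherwise it
# defaults to the wrapped function's __name__.
@HostFunctionTimer.timer("cpu_matmul")
def cpu_matmul(n: int = 512) -> torch.Tensor:
    a = torch.randn(n, n)
    b = torch.randn(n, n)
    return a @ b


cpu_matmul()
FunctionTimer.print_log()  # entries accumulate in the shared class-level log dict

# Manual timing with the Timer class: start/stop a named host timer,
# then read the elapsed time (in milliseconds) from get_info().
timer = Timer()
timer.start_timer("matmul_block", host_timer=True)
cpu_matmul()
timer.stop_timer("matmul_block")
print(timer.get_info())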
3 changes: 3 additions & 0 deletions examples/mnist_influence/compute_influences.py
@@ -11,6 +11,8 @@
    construct_mlp,
)

from analog.timer import FunctionTimer

parser = argparse.ArgumentParser("MNIST Influence Analysis")
parser.add_argument("--data", type=str, default="mnist", help="mnist or fmnist")
parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
@@ -79,6 +81,7 @@
)
_, top_influential_data = torch.topk(if_scores, k=10)

FunctionTimer.print_log()
# Save
if_scores = if_scores.cpu().numpy().tolist()[0]
torch.save(if_scores, "if_analog.pt")