
Commit

fix
eatpk committed Feb 19, 2024
1 parent c36340d commit 1f0c55b
Showing 6 changed files with 54 additions and 173 deletions.
4 changes: 3 additions & 1 deletion analog/analog.py
@@ -17,6 +17,7 @@
from analog.lora import LoRAHandler
from analog.lora.utils import is_lora
from analog.state import AnaLogState
from analog.timer.timer import DeviceFunctionTimer
from analog.utils import (
get_logger,
get_rank,
@@ -236,7 +237,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
interfere with further operations on the model or with future logging sessions.
"""
self.logger.update()

@DeviceFunctionTimer.timer
def build_log_dataset(self):
"""
Constructs the log dataset from the stored logs. This dataset can then be used
@@ -249,6 +250,7 @@ def build_log_dataset(self):
log_dataset = LogDataset(log_dir=self.log_dir, config=self.influence_config)
return log_dataset

@DeviceFunctionTimer.timer
def build_log_dataloader(
self, batch_size: int = 16, num_workers: int = 0, pin_memory: bool = False
):
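
The decorator added throughout this commit, DeviceFunctionTimer.timer, comes from the new analog.timer package, whose implementation is not part of this diff. As a rough sketch of what a device-aware function timer of this shape could look like (an assumed implementation for illustration only, including the class names, the shared log dict, and the print_log method), consider:

import functools
import time
from collections import defaultdict

import torch


class FunctionTimer:
    # Accumulated wall-clock time per decorated function, in milliseconds.
    log = defaultdict(float)

    @classmethod
    def timer(cls, fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            if torch.cuda.is_available():
                # CUDA kernels launch asynchronously, so bracket the call with
                # events and synchronize before reading the elapsed time.
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                result = fn(*args, **kwargs)
                end.record()
                torch.cuda.synchronize()
                cls.log[fn.__qualname__] += start.elapsed_time(end)
            else:
                t0 = time.perf_counter()
                result = fn(*args, **kwargs)
                cls.log[fn.__qualname__] += (time.perf_counter() - t0) * 1000.0
            return result

        return wrapped

    @classmethod
    def print_log(cls):
        for name, ms in sorted(cls.log.items(), key=lambda kv: -kv[1]):
            print(f"{name}: {ms:.2f} ms")


class DeviceFunctionTimer(FunctionTimer):
    # In this sketch the subclass shares FunctionTimer.log, so timings recorded
    # via DeviceFunctionTimer.timer are visible to FunctionTimer.print_log().
    pass

With such a timer in place, each decorated call would contribute an entry to the log that FunctionTimer.print_log() dumps at the end of examples/mnist_influence/compute_influences.py below.
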
41 changes: 22 additions & 19 deletions analog/analysis/influence_function.py
@@ -5,6 +5,7 @@
from einops import einsum, rearrange, reduce
from analog.config import InfluenceConfig
from analog.state import AnaLogState
from analog.timer.timer import DeviceFunctionTimer
from analog.utils import get_logger, nested_dict
from analog.analysis.utils import synchronize_device

@@ -24,11 +25,12 @@ def __init__(self, config: InfluenceConfig, state: AnaLogState):
self.influence_scores = pd.DataFrame()
self.flatten = config.flatten

@DeviceFunctionTimer.timer
@torch.no_grad()
def precondition(
self,
src_log: Dict[str, Dict[str, torch.Tensor]],
damping: Optional[float] = None,
):
"""
Precondition gradients using the Hessian.
@@ -84,12 +86,12 @@ def precondition(

@torch.no_grad()
def compute_influence(
self,
src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
tgt_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
mode: Optional[str] = "dot",
precondition: Optional[bool] = True,
damping: Optional[float] = None,
):
"""
Compute influence scores between two gradient dictionaries.
@@ -172,10 +174,10 @@ def _dot_product_logs(self, src_module, tgt_module):

@torch.no_grad()
def compute_self_influence(
self,
src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
precondition: Optional[bool] = True,
damping: Optional[float] = None,
):
"""
Compute self-influence scores. This can be used for uncertainty estimation.
@@ -212,13 +214,14 @@ def flatten_log(self, src):
to_cat.append(log.view(bsz, -1))
return torch.cat(to_cat, dim=1)

@DeviceFunctionTimer.timer
def compute_influence_all(
self,
src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
loader: torch.utils.data.DataLoader,
mode: Optional[str] = "dot",
precondition: Optional[bool] = True,
damping: Optional[float] = None,
):
"""
Compute influence scores against all train data in the log. This can be used
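
For orientation, the "dot" mode referenced in compute_influence and compute_influence_all presumably reduces to per-sample dot products between the (optionally preconditioned) source gradients and the target gradients, accumulated over modules. A standalone sketch of that score, written against hypothetical flattened per-sample gradient dictionaries rather than the class's own log format, might be:

from typing import Dict

import torch
from einops import einsum


def dot_product_influence(
    src_grads: Dict[str, torch.Tensor],  # module name -> (n_src, d) per-sample grads
    tgt_grads: Dict[str, torch.Tensor],  # module name -> (n_tgt, d) per-sample grads
) -> torch.Tensor:
    """Return an (n_src, n_tgt) matrix of per-module dot products, summed over modules."""
    total = None
    for name, src in src_grads.items():
        tgt = tgt_grads[name].to(src.device)
        scores = einsum(src.flatten(1), tgt.flatten(1), "n d, m d -> n m")
        total = scores if total is None else total + scores
    return total

Preconditioning, when enabled, would rescale the source gradients with an approximate inverse Hessian (with optional damping) before this product is taken, which is what the precondition method above is responsible for.
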
3 changes: 3 additions & 0 deletions analog/logging/log_saver.py
@@ -1,6 +1,7 @@
from concurrent.futures import ThreadPoolExecutor
import torch

from analog.timer.timer import HostFunctionTimer, DeviceFunctionTimer
from analog.utils import nested_dict, to_numpy, get_rank
from analog.logging.mmap import MemoryMapHandler

@@ -21,6 +22,7 @@ def __init__(self, config, state):
self.buffer = nested_dict()
self.buffer_size = 0

@DeviceFunctionTimer.timer
def buffer_write(self, binfo):
"""
Add log state on exit.
@@ -85,6 +87,7 @@ def _flush_serialized(self, log_dir) -> str:
del buffer_list
return log_dir

@DeviceFunctionTimer.timer
def flush(self) -> None:
"""
For the DefaultHandler, there's no batch operation needed since each add operation writes to the file.
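
buffer_write and flush, the two newly timed methods here, appear to implement an accumulate-then-flush pattern: logs are staged in an in-memory buffer and periodically serialized to the log directory. A much-simplified sketch of that pattern (hypothetical code; the actual LogSaver uses nested buffers and MemoryMapHandler) is:

import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch


class SimpleLogSaver:
    def __init__(self, log_dir: str, flush_threshold_bytes: int = 1 << 28):
        self.log_dir = log_dir
        self.buffer = {}              # data_id -> {module_name: np.ndarray}
        self.buffer_size = 0          # bytes currently buffered
        self.flush_threshold_bytes = flush_threshold_bytes
        self.flush_count = 0
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending = []

    def buffer_write(self, data_id, module_name, tensor: torch.Tensor):
        arr = tensor.detach().cpu().numpy()
        self.buffer.setdefault(data_id, {})[module_name] = arr
        self.buffer_size += arr.nbytes
        if self.buffer_size >= self.flush_threshold_bytes:
            self.flush()

    def flush(self):
        if not self.buffer:
            return
        path = os.path.join(self.log_dir, f"chunk_{self.flush_count}.npz")
        flat = {f"{d}.{m}": a for d, mods in self.buffer.items() for m, a in mods.items()}
        # Hand serialization to a worker thread so the training loop is not blocked.
        self._pending.append(self._pool.submit(np.savez, path, **flat))
        self.flush_count += 1
        self.buffer = {}
        self.buffer_size = 0

    def finalize(self):
        # Wait for outstanding writes before shutting down.
        for future in self._pending:
            future.result()
        self._pool.shutdown()
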
41 changes: 23 additions & 18 deletions analog/logging/logger.py
@@ -10,15 +10,16 @@
from analog.logging.option import LogOption
from analog.logging.log_saver import LogSaver
from analog.logging.utils import compute_per_sample_gradient
from analog.timer.timer import DeviceFunctionTimer
from analog.utils import get_logger


class HookLogger:
def __init__(
self,
config: LoggingConfig,
state: AnaLogState,
binfo: BatchInfo,
) -> None:
"""
Initializes the LoggingHandler with empty lists for hooks.
@@ -42,6 +43,7 @@ def __init__(
self.grad_hooks = []
self.tensor_hooks = []

@DeviceFunctionTimer.timer
def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):
"""
Add log state on exit.
@@ -59,6 +61,7 @@ def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):

return log

@DeviceFunctionTimer.timer
def update(self):
# Update statistics
for stat in self.opt.statistic["grad"]:
@@ -82,8 +85,9 @@ def update(self):
self.log_saver.buffer_write(binfo=self.binfo)
self.log_saver.flush()

@DeviceFunctionTimer.timer
def _forward_hook_fn(
self, module: nn.Module, inputs: Tuple[torch.Tensor], module_name: str
) -> None:
"""
Internal forward hook function.
@@ -130,13 +134,13 @@ def _forward_hook_fn(
data=activations,
cpu_offload=self.cpu_offload,
)

@DeviceFunctionTimer.timer
def _backward_hook_fn(
self,
module: nn.Module,
grad_inputs: Tuple[torch.Tensor],
grad_outputs: Tuple[torch.Tensor],
module_name: str,
) -> None:
"""
Internal backward hook function.
@@ -171,13 +175,13 @@ def _backward_hook_fn(
data=error,
cpu_offload=self.cpu_offload,
)

@DeviceFunctionTimer.timer
def _grad_hook_fn(
self,
module: nn.Module,
inputs: Tuple[torch.Tensor],
outputs: Tuple[torch.Tensor],
module_name: str,
) -> None:
"""
Internal gradient hook function.
@@ -270,6 +274,7 @@ def _tensor_backward_hook_fn(self, grad: torch.Tensor, tensor_name: str) -> None
cpu_offload=self.cpu_offload,
)

@DeviceFunctionTimer.timer
def register_all_module_hooks(self) -> None:
"""
Register all module hooks.
@@ -333,7 +338,7 @@ def finalize(self):
self.log_saver.finalize()

def clear(
self, hook: bool = True, module: bool = True, buffer: bool = True
) -> None:
"""
Clear all hooks and internal states.
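
The hook functions timed above are the per-module callbacks that PyTorch invokes during the forward and backward passes, and register_all_module_hooks wires them up for every tracked module. A minimal sketch of that registration pattern (assumed shape for illustration, not the actual HookLogger code) looks like:

import torch.nn as nn


def register_hooks(model: nn.Module, on_forward, on_backward):
    """Attach forward pre-hooks (activations) and full backward hooks (gradients)."""
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Bind the module name via a default argument, mirroring the
            # module_name parameter of the hook functions above.
            handles.append(
                module.register_forward_pre_hook(
                    lambda mod, inputs, name=name: on_forward(mod, inputs, name)
                )
            )
            handles.append(
                module.register_full_backward_hook(
                    lambda mod, grad_in, grad_out, name=name: on_backward(
                        mod, grad_in, grad_out, name
                    )
                )
            )
    return handles  # call handle.remove() on each one to clear the hooks
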
3 changes: 3 additions & 0 deletions examples/mnist_influence/compute_influences.py
@@ -11,6 +11,8 @@
construct_mlp,
)

from analog.timer import FunctionTimer

parser = argparse.ArgumentParser("MNIST Influence Analysis")
parser.add_argument("--data", type=str, default="mnist", help="mnist or fmnist")
parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
@@ -79,6 +81,7 @@
)
_, top_influential_data = torch.topk(if_scores, k=10)

FunctionTimer.print_log()
# Save
if_scores = if_scores.cpu().numpy().tolist()[0]
torch.save(if_scores, "if_analog.pt")
135 changes: 0 additions & 135 deletions tests/util/timer.py

This file was deleted.
