update logging setup interface (#106)
sangkeun00 authored Jun 3, 2024
1 parent efd1258 commit 1ca5ec9
Showing 12 changed files with 170 additions and 193 deletions.
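
At a glance, the interface change: `logix.setup` now takes a mapping from log type to a list of plugins, rather than a mapping from option name to log type. A minimal before/after sketch, taken from the example-script updates below (reading the new form as "log type -> plugin list" is inferred from the logger changes in this commit, not stated explicitly in the commit message):

# old interface: option name mapped to a log type
logix.setup({"log": "grad"})

# new interface: log type mapped to a list of plugins (here, raw "log")
logix.setup({"grad": ["log"]})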
2 changes: 1 addition & 1 deletion examples/bert/compute_influence.py
@@ -35,7 +35,7 @@ def main():
log_loader = logix.build_log_dataloader()

# influence analysis
logix.setup({"log": "grad"})
logix.setup({"grad": ["log"]})
logix.eval()
for batch in test_loader:
data_id = tokenizer.batch_decode(batch["input_ids"])
2 changes: 1 addition & 1 deletion examples/cifar/compute_influences.py
@@ -64,7 +64,7 @@
log_loader = logix.build_log_dataloader()

logix.eval()
logix.setup({"log": "grad"})
logix.setup({"grad": ["log"]})
for test_input, test_target in test_loader:
with logix(data_id=id_gen(test_input)):
test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
2 changes: 1 addition & 1 deletion examples/language_modeling/compute_influence.py
@@ -74,7 +74,7 @@ def main():
log_loader = logix.build_log_dataloader(batch_size=64)

# Influence analysis
logix.setup({"log": "grad"})
logix.setup({"grad": ["log"]})
logix.eval()
merged_test_logs = []
for idx, batch in enumerate(tqdm(data_loader)):
2 changes: 1 addition & 1 deletion examples/mnist/compute_influences.py
@@ -62,7 +62,7 @@
)

# logix.add_analysis({"influence": InfluenceFunction})
logix.setup({"log": "grad"})
logix.setup({"grad": ["log"]})
logix.eval()
for test_input, test_target in test_loader:
with logix(data_id=id_gen(test_input)):
2 changes: 1 addition & 1 deletion examples/mnist/compute_influences_manual.py
@@ -67,7 +67,7 @@
)

# logix.add_analysis({"influence": InfluenceFunction})
logix.setup({"log": "grad"})
logix.setup({"grad": ["log"]})
logix.eval()
for test_input, test_target in test_loader:
### Start
2 changes: 1 addition & 1 deletion logix/huggingface/callback.py
@@ -46,7 +46,7 @@ def on_train_begin(self, args, state, control, **kwargs):
self.logix.initialize_from_log()

if self.args.mode in ["influence", "self_influence"]:
self.logix.setup({"log": "grad"})
self.logix.setup({"grad": ["log"]})
self.logix.eval()

state.epoch = 0
72 changes: 28 additions & 44 deletions logix/logging/logger.py
@@ -7,6 +7,7 @@
from logix.batch_info import BatchInfo
from logix.config import LoggingConfig
from logix.state import LogIXState
+from logix.statistic import Log
from logix.logging.option import LogOption
from logix.logging.log_saver import LogSaver
from logix.logging.utils import compute_per_sample_gradient
@@ -61,13 +62,12 @@ def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):

def save_log(self):
# save log to disk
-if any(self.opt.save.values()):
-    self.log_saver.buffer_write(binfo=self.binfo)
-    self.log_saver.flush()
+self.log_saver.buffer_write(binfo=self.binfo)
+self.log_saver.flush()

-def update(self):
-    # Update statistics
-    for stat in self.opt.statistic["grad"]:
+def update(self, save: bool = False):
+    # gradient plugin has to be executed after accumulating all gradients
+    for stat in self.opt.grad[1:]:
for module_name, _ in self.binfo.log.items():
stat.update(
state=self.state,
@@ -84,7 +84,8 @@ def update(self):
torch.cuda.current_stream().synchronize()

# Write and flush the buffer if necessary
-self.save_log()
+if save:
+    self.save_log()

def _forward_hook_fn(
self, module: nn.Module, inputs: Tuple[torch.Tensor], module_name: str
@@ -100,7 +101,6 @@ def _forward_hook_fn(
assert len(inputs) == 1

activations = inputs[0]
-log = self.binfo.log[module_name]

# If `mask` is not None, apply the mask to activations. This is
# useful for example when you work with sequence models that use
@@ -118,14 +118,8 @@
if self.dtype is not None:
activations = activations.to(dtype=self.dtype)

if self.opt.log["forward"]:
if "forward" not in log:
log["forward"] = activations
else:
log["forward"] += activations

for stat in self.opt.statistic["forward"]:
stat.update(
for plugin in self.opt.forward:
plugin.update(
state=self.state,
binfo=self.binfo,
module=module,
@@ -154,19 +148,12 @@ def _backward_hook_fn(
assert len(grad_outputs) == 1

error = grad_outputs[0]
-log = self.binfo.log[module_name]

if self.dtype is not None:
error = error.to(dtype=self.dtype)

if self.opt.log["backward"]:
if "backward" not in log:
log["backward"] = error
else:
log["backward"] += error

for stat in self.opt.statistic["backward"]:
stat.update(
for plugin in self.opt.backward:
plugin.update(
state=self.state,
binfo=self.binfo,
module=module,
@@ -194,24 +181,29 @@ def _grad_hook_fn(
"""
assert len(inputs) == 1

-log = self.binfo.log[module_name]

# In case, the same module is used multiple times in the forward pass,
# we need to accumulate the gradients. We achieve this by using the
# additional tensor hook on the output of the module.
def _grad_backward_hook_fn(grad: torch.Tensor):
if self.opt.log["grad"]:
if len(self.opt.grad) > 0:
assert self.opt.grad[0] == Log
per_sample_gradient = compute_per_sample_gradient(
inputs[0], grad, module
)

if self.dtype is not None:
per_sample_gradient = per_sample_gradient.to(dtype=self.dtype)

if "grad" not in log:
log["grad"] = per_sample_gradient
else:
log["grad"] += per_sample_gradient
for plugin in self.opt.grad[:1]:
plugin.update(
state=self.state,
binfo=self.binfo,
module=module,
module_name=module_name,
log_type="grad",
data=per_sample_gradient,
cpu_offload=self.cpu_offload,
)

tensor_hook = outputs.register_hook(_grad_backward_hook_fn)
self.tensor_hooks.append(tensor_hook)
@@ -227,15 +219,11 @@ def _tensor_forward_hook_fn(self, tensor: torch.Tensor, tensor_name: str) -> None
tensor: The tensor triggering the hook.
tensor_name (str): A string identifier for the tensor, useful for logging.
"""
-log = self.binfo.log[tensor_name]

if self.dtype is not None:
tensor = tensor.to(dtype=self.dtype)

log["forward"] = tensor

for stat in self.opt.statistic["forward"]:
stat.update(
for plugin in self.opt.forward:
plugin.update(
state=self.state,
binfo=self.binfo,
module=None,
@@ -256,15 +244,11 @@ def _tensor_backward_hook_fn(self, grad: torch.Tensor, tensor_name: str) -> None
grad: The gradient tensor triggering the hook.
tensor_name (str): A string identifier for the tensor whose gradient is being tracked.
"""
-log = self.binfo.log[tensor_name]

if self.dtype is not None:
grad = grad.to(dtype=self.dtype)

log["backward"] = grad

for stat in self.opt.statistic["backward"]:
stat.update(
for plugin in self.opt.backward:
plugin.update(
state=self.state,
binfo=self.binfo,
module=None,
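
Based on the hook code above, statistic/log plugins are now dispatched through lists on the logging option (`self.opt.forward`, `self.opt.backward`, `self.opt.grad`, with the raw `Log` plugin expected first in `grad`), each exposing an `update(...)` entry point. A hypothetical plugin sketch for illustration only; the class name and staticmethod shape are assumptions, and only the keyword arguments are taken from the calls in logger.py above:

class GradNormPlugin:  # hypothetical; not part of this commit
    @staticmethod
    def update(state, binfo, module, module_name, log_type, data, cpu_offload=False):
        # `data` is the tensor handed over by the hook, e.g. the per-sample
        # gradient when log_type == "grad"; store its per-sample norm.
        norms = data.flatten(start_dim=1).norm(dim=1)
        binfo.log[module_name][f"{log_type}_norm"] = norms.cpu() if cpu_offload else norms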