diff --git a/examples/bert/compute_influence.py b/examples/bert/compute_influence.py
index e67e639b..3d0d9247 100644
--- a/examples/bert/compute_influence.py
+++ b/examples/bert/compute_influence.py
@@ -35,7 +35,7 @@ def main():
     log_loader = logix.build_log_dataloader()
 
     # influence analysis
-    logix.setup({"log": "grad"})
+    logix.setup({"grad": ["log"]})
     logix.eval()
     for batch in test_loader:
         data_id = tokenizer.batch_decode(batch["input_ids"])
diff --git a/examples/cifar/compute_influences.py b/examples/cifar/compute_influences.py
index b059dfe4..6fe09f27 100644
--- a/examples/cifar/compute_influences.py
+++ b/examples/cifar/compute_influences.py
@@ -64,7 +64,7 @@
 log_loader = logix.build_log_dataloader()
 
 logix.eval()
-logix.setup({"log": "grad"})
+logix.setup({"grad": ["log"]})
 for test_input, test_target in test_loader:
     with logix(data_id=id_gen(test_input)):
         test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
diff --git a/examples/language_modeling/compute_influence.py b/examples/language_modeling/compute_influence.py
index 6c4ce011..f5ea354f 100644
--- a/examples/language_modeling/compute_influence.py
+++ b/examples/language_modeling/compute_influence.py
@@ -74,7 +74,7 @@ def main():
     log_loader = logix.build_log_dataloader(batch_size=64)
 
     # Influence analysis
-    logix.setup({"log": "grad"})
+    logix.setup({"grad": ["log"]})
     logix.eval()
     merged_test_logs = []
     for idx, batch in enumerate(tqdm(data_loader)):
diff --git a/examples/mnist/compute_influences.py b/examples/mnist/compute_influences.py
index 6405f350..29b9d390 100644
--- a/examples/mnist/compute_influences.py
+++ b/examples/mnist/compute_influences.py
@@ -62,7 +62,7 @@
 )
 
 # logix.add_analysis({"influence": InfluenceFunction})
-logix.setup({"log": "grad"})
+logix.setup({"grad": ["log"]})
 logix.eval()
 for test_input, test_target in test_loader:
     with logix(data_id=id_gen(test_input)):
diff --git a/examples/mnist/compute_influences_manual.py b/examples/mnist/compute_influences_manual.py
index e9f892ef..5956a48c 100644
--- a/examples/mnist/compute_influences_manual.py
+++ b/examples/mnist/compute_influences_manual.py
@@ -67,7 +67,7 @@
 )
 
 # logix.add_analysis({"influence": InfluenceFunction})
-logix.setup({"log": "grad"})
+logix.setup({"grad": ["log"]})
 logix.eval()
 for test_input, test_target in test_loader:
     ### Start
diff --git a/logix/huggingface/callback.py b/logix/huggingface/callback.py
index a0529235..4b9b589b 100644
--- a/logix/huggingface/callback.py
+++ b/logix/huggingface/callback.py
@@ -46,7 +46,7 @@ def on_train_begin(self, args, state, control, **kwargs):
             self.logix.initialize_from_log()
 
         if self.args.mode in ["influence", "self_influence"]:
-            self.logix.setup({"log": "grad"})
+            self.logix.setup({"grad": ["log"]})
             self.logix.eval()
 
         state.epoch = 0
diff --git a/logix/logging/logger.py b/logix/logging/logger.py
index 6b596a31..c3d66d7a 100644
--- a/logix/logging/logger.py
+++ b/logix/logging/logger.py
@@ -7,6 +7,7 @@
 from logix.batch_info import BatchInfo
 from logix.config import LoggingConfig
 from logix.state import LogIXState
+from logix.statistic import Log
 from logix.logging.option import LogOption
 from logix.logging.log_saver import LogSaver
 from logix.logging.utils import compute_per_sample_gradient
@@ -61,13 +62,12 @@ def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):
 
     def save_log(self):
         # save log to disk
-        if any(self.opt.save.values()):
-            self.log_saver.buffer_write(binfo=self.binfo)
-            self.log_saver.flush()
+        self.log_saver.buffer_write(binfo=self.binfo)
+        self.log_saver.flush()
 
-    def update(self):
-        # Update statistics
-        for stat in self.opt.statistic["grad"]:
+    def update(self, save: bool = False):
+        # gradient plugin has to be executed after accumulating all gradients
+        for stat in self.opt.grad[1:]:
             for module_name, _ in self.binfo.log.items():
                 stat.update(
                     state=self.state,
@@ -84,7 +84,8 @@ def update(self):
             torch.cuda.current_stream().synchronize()
 
         # Write and flush the buffer if necessary
-        self.save_log()
+        if save:
+            self.save_log()
 
     def _forward_hook_fn(
         self, module: nn.Module, inputs: Tuple[torch.Tensor], module_name: str
@@ -100,7 +101,6 @@ def _forward_hook_fn(
         assert len(inputs) == 1
 
         activations = inputs[0]
-        log = self.binfo.log[module_name]
 
         # If `mask` is not None, apply the mask to activations. This is
         # useful for example when you work with sequence models that use
@@ -118,14 +118,8 @@ def _forward_hook_fn(
         if self.dtype is not None:
             activations = activations.to(dtype=self.dtype)
 
-        if self.opt.log["forward"]:
-            if "forward" not in log:
-                log["forward"] = activations
-            else:
-                log["forward"] += activations
-
-        for stat in self.opt.statistic["forward"]:
-            stat.update(
+        for plugin in self.opt.forward:
+            plugin.update(
                 state=self.state,
                 binfo=self.binfo,
                 module=module,
@@ -154,19 +148,12 @@ def _backward_hook_fn(
         assert len(grad_outputs) == 1
 
         error = grad_outputs[0]
-        log = self.binfo.log[module_name]
 
         if self.dtype is not None:
             error = error.to(dtype=self.dtype)
 
-        if self.opt.log["backward"]:
-            if "backward" not in log:
-                log["backward"] = error
-            else:
-                log["backward"] += error
-
-        for stat in self.opt.statistic["backward"]:
-            stat.update(
+        for plugin in self.opt.backward:
+            plugin.update(
                 state=self.state,
                 binfo=self.binfo,
                 module=module,
@@ -194,13 +181,12 @@ def _grad_hook_fn(
         """
         assert len(inputs) == 1
 
-        log = self.binfo.log[module_name]
-
         # In case, the same module is used multiple times in the forward pass,
         # we need to accumulate the gradients. We achieve this by using the
         # additional tensor hook on the output of the module.
        def _grad_backward_hook_fn(grad: torch.Tensor):
-            if self.opt.log["grad"]:
+            if len(self.opt.grad) > 0:
+                assert self.opt.grad[0] == Log
                per_sample_gradient = compute_per_sample_gradient(
                    inputs[0], grad, module
                )
@@ -208,10 +194,16 @@ def _grad_backward_hook_fn(grad: torch.Tensor):
                 if self.dtype is not None:
                     per_sample_gradient = per_sample_gradient.to(dtype=self.dtype)
 
-                if "grad" not in log:
-                    log["grad"] = per_sample_gradient
-                else:
-                    log["grad"] += per_sample_gradient
+                for plugin in self.opt.grad[:1]:
+                    plugin.update(
+                        state=self.state,
+                        binfo=self.binfo,
+                        module=module,
+                        module_name=module_name,
+                        log_type="grad",
+                        data=per_sample_gradient,
+                        cpu_offload=self.cpu_offload,
+                    )
 
         tensor_hook = outputs.register_hook(_grad_backward_hook_fn)
         self.tensor_hooks.append(tensor_hook)
@@ -227,15 +219,11 @@ def _tensor_forward_hook_fn(self, tensor: torch.Tensor, tensor_name: str) -> Non
             tensor: The tensor triggering the hook.
             tensor_name (str): A string identifier for the tensor, useful for logging.
         """
-        log = self.binfo.log[tensor_name]
-
         if self.dtype is not None:
             tensor = tensor.to(dtype=self.dtype)
 
-        log["forward"] = tensor
-
-        for stat in self.opt.statistic["forward"]:
-            stat.update(
+        for plugin in self.opt.forward:
+            plugin.update(
                 state=self.state,
                 binfo=self.binfo,
                 module=None,
@@ -256,15 +244,11 @@ def _tensor_backward_hook_fn(self, grad: torch.Tensor, tensor_name: str) -> None
             grad: The gradient tensor triggering the hook.
             tensor_name (str): A string identifier for the tensor whose gradient is being tracked.
         """
-        log = self.binfo.log[tensor_name]
-
         if self.dtype is not None:
             grad = grad.to(dtype=self.dtype)
 
-        log["backward"] = grad
-
-        for stat in self.opt.statistic["backward"]:
-            stat.update(
+        for plugin in self.opt.backward:
+            plugin.update(
                 state=self.state,
                 binfo=self.binfo,
                 module=None,
diff --git a/logix/logging/option.py b/logix/logging/option.py
index b8b6f68a..90b4ad29 100644
--- a/logix/logging/option.py
+++ b/logix/logging/option.py
@@ -1,14 +1,67 @@
-from typing import Any
+from typing import List, Any
 
-from logix.statistic import Covariance, CorrectedEigval
+from logix.statistic import Log, Mean, Variance, Covariance, CorrectedEigval
 from logix.utils import get_logger
 
 
+_PLUGIN_MAPPING = {
+    "log": Log,
+    "mean": Mean,
+    "variance": Variance,
+    "covariance": Covariance,
+    "corrected_eigval": CorrectedEigval,
+}
+_PLUGIN_LIST = [Log, Mean, Variance, Covariance, CorrectedEigval]
+
+
+def _reorder_plugins(plugins):
+    """
+    Reorder the plugins to ensure that the plugins are in the correct order. Especially,
+    it is important to ensure that the Log plugin is the first plugin.
+    Args:
+        plugins: List of plugins.
+    Returns:
+        List of plugins in the correct order.
+    """
+    order = [Log, Mean, Variance, Covariance, CorrectedEigval]
+    ordered_plugins = []
+    for plugin in order:
+        if plugin in plugins:
+            ordered_plugins.append(plugin)
+    return ordered_plugins
+
+
+def _to_plugins(plugins: List[Any], is_grad: bool = False):
+    """
+    Convert and reorder the list of plugins to the actual plugins.
+    """
+    # Convert the string plugins to the actual plugins.
+    for idx, plugin in enumerate(plugins):
+        if isinstance(plugin, str):
+            assert plugin in _PLUGIN_MAPPING
+            plugins[idx] = _PLUGIN_MAPPING[plugin]
+        assert plugins[idx] in _PLUGIN_LIST
+
+    # reorder the plugins to ensure that the plugins are in the correct order.
+    plugins = _reorder_plugins(plugins)
+
+    if is_grad:
+        # Ensure that the Log plugin is the first plugin.
+        if len(plugins) > 0 and Log not in plugins:
+            get_logger().warning(
+                "The `Log` plugin is not in the list of plugins. "
+                "The `Log` plugin will be inserted at the beginning of the list."
+            )
+            plugins.insert(0, Log)
+
+    return plugins
+
+
 class LogOption:
     def __init__(self):
-        self._log = {}
-        self._save = {}
-        self._statistic = {}
+        self.forward = []
+        self.backward = []
+        self.grad = []
 
         self.clear()
 
@@ -21,108 +74,20 @@ def setup(self, log_option_kwargs):
             save: Saving configurations.
             statistic: Statistic configurations.
""" - log = log_option_kwargs.get("log", None) - save = log_option_kwargs.get("save", None) - statistic = log_option_kwargs.get("statistic", None) self.clear() - if log is not None: - if isinstance(log, str): - self._log[log] = True - elif isinstance(log, list): - for l in log: - self._log[l] = True - elif isinstance(log, dict): - self._log = log - else: - raise ValueError(f"Unsupported log type: {type(log)}") - - if save is not None: - if isinstance(save, str): - self._save[save] = True - elif isinstance(save, list): - for s in save: - self._save[s] = True - elif isinstance(save, dict): - self._save = save - else: - raise ValueError(f"Unsupported save type: {type(save)}") - - if statistic is not None: - if isinstance(statistic, str): - if statistic == "kfac": - statistic = { - "forward": [Covariance], - "backward": [Covariance], - "grad": [], - } - elif statistic == "ekfac": - statistic = { - "forward": [], - "backward": [], - "grad": [CorrectedEigval], - } - else: - raise ValueError(f"Unknown statistic: {statistic}") - - assert isinstance(statistic, dict) - self._statistic = statistic - - self._sanity_check() - - def _sanity_check(self): - # forward - if self._save["forward"] and not self._log["forward"]: - get_logger().warning( - "Saving forward activations without logging it is not allowed. " - + "Setting log['forward'] to True automatically." - ) - self._log["forward"] = True + forward = log_option_kwargs.get("forward", []) + backward = log_option_kwargs.get("backward", []) + grad = log_option_kwargs.get("grad", []) - # backward - if self._save["backward"] and not self._log["backward"]: - get_logger().warning( - "Saving backward error signals without logging it is not allowed. " - + "Setting log['backward'] to True automatically." - ) - self._log["backward"] = True - - # grad - if (self._save["grad"] or len(self._statistic["grad"]) > 0) and not self._log[ - "grad" - ]: - get_logger().warning( - "Saving gradients or computing statistic without logging it " - + "is not allowed. Setting log['grad'] to True automatically." - ) - self._log["grad"] = True - - def eval(self): - """ - Enable the evaluation mode. This will turn of saving and updating - statistic. - """ - self.clear(log=False, save=True, statistic=True) + self.forward = _to_plugins(forward) + self.backward = _to_plugins(backward) + self.grad = _to_plugins(grad) - def clear(self, log=True, save=True, statistic=True): + def clear(self): """ Clear all logging configurations. 
""" - if log: - self._log = {"forward": False, "backward": False, "grad": False} - if save: - self._save = {"forward": False, "backward": False, "grad": False} - if statistic: - self._statistic = {"forward": [], "backward": [], "grad": []} - - @property - def log(self): - return self._log - - @property - def save(self): - return self._save - - @property - def statistic(self): - return self._statistic + self.forward = [] + self.backward = [] + self.grad = [] diff --git a/logix/logix.py b/logix/logix.py index f8966e25..eb462bd3 100644 --- a/logix/logix.py +++ b/logix/logix.py @@ -69,6 +69,8 @@ def __init__( # LogIX state self.state: LogIXState = LogIXState() self.binfo: BatchInfo = BatchInfo() + self._save: bool = False + self._save_batch: bool = False # Initialize logger self.logger: HookLogger = HookLogger( @@ -233,6 +235,7 @@ def __call__( self, data_id: Iterable[Any], mask: Optional[torch.Tensor] = None, + save: bool = False, ): """ Args: @@ -246,6 +249,8 @@ def __call__( self.binfo.data_id = data_id self.binfo.mask = mask + self._save_batch = save + return self def __enter__(self): @@ -267,7 +272,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: This method is essential for ensuring that there are no lingering hooks that could interfere with further operations on the model or with future logging sessions. """ - self.logger.update() + self.logger.update(save=self._save_batch or self._save) def start( self, data_id: Iterable[Any], mask: Optional[torch.Tensor] = None @@ -295,7 +300,7 @@ def end(self, save: bool = False) -> None: This is another programming interface for logging. Instead of using the context manager, we also allow users to manually specify "start" and "end" points for logging. """ - self.logger.update() + self.logger.update(save=save or self._save) def build_log_dataset(self, flatten: bool = False) -> torch.utils.data.Dataset: """ @@ -528,11 +533,17 @@ def setup(self, log_option_kwargs: Dict[str, Any]) -> None: """ self.logger.opt.setup(log_option_kwargs) + def save(self, enable: bool = True) -> None: + """ + Turn on saving. + """ + self._save = enable + def eval(self) -> None: """ Set the state of LogIX for testing. 
""" - self.logger.opt.eval() + self.save(False) def clear(self) -> None: """ diff --git a/logix/scheduler.py b/logix/scheduler.py index 43f38a5a..43cd1db9 100644 --- a/logix/scheduler.py +++ b/logix/scheduler.py @@ -1,5 +1,5 @@ from logix import LogIX -from logix.statistic import Covariance +from logix.statistic import Covariance, Log, CorrectedEigval class LogIXScheduler: @@ -10,65 +10,81 @@ def __init__( hessian: str = "none", save: str = "none", ): - self.logix = logix + self._logix = logix + + self._lora = lora + self._hessian = hessian + self._save = save self._epoch = -1 - self._lora_epoch = -1 self._logix_state_schedule = [] - self.sanity_check(lora, hessian, save) - self.configure_lora_epoch(lora) + self.sanity_check(lora, hessian) self.configure_schedule(lora, hessian, save) self._schedule_iterator = iter(self._logix_state_schedule) - def sanity_check(self, lora: str, hessian: str, save: str): + def sanity_check(self, lora: str, hessian: str): assert lora in ["none", "random", "pca"] assert hessian in ["none", "raw", "kfac", "ekfac"] - assert save in ["none", "grad"] - def configure_lora_epoch(self, lora: str): + def get_lora_epoch(self, lora: str) -> int: if lora == "random": - self._lora_epoch = 0 + return 0 elif lora == "pca": - self._lora_epoch = 1 + return 1 + return -1 + + def get_save_epoch(self, save: str) -> int: + if save != "none": + return len(self) - 1 + return -1 - def configure_schedule(self, lora: str, hessian: str, save: str): - # (log, hessian, save) for logix + def configure_schedule(self, lora: str, hessian: str, save: str) -> None: if lora == "pca": - self._logix_state_schedule.append({"statistic": "kfac"}) + self._logix_state_schedule.append( + {"forward": [Covariance], "backward": [Covariance]} + ) if hessian == "ekfac": - self._logix_state_schedule.append({"statistic": "kfac"}) - - last_state = {} - # log - if save in ["grad"] or hessian in ["raw", "ekfac"]: - last_state["log"] = "grad" - # statistic - if hessian in ["kfac", "ekfac"]: - last_state["statistic"] = hessian - elif hessian in ["raw"]: - last_state["statistic"] = { - "grad": [Covariance], - "forward": [], - "backward": [], - } - # save - if save in ["grad"]: - last_state["save"] = save + self._logix_state_schedule.append( + {"forward": [Covariance], "backward": [Covariance]} + ) + + last_state = {"forward": [], "backward": [], "grad": []} + if save != "none": + last_state[save].append(Log) + if hessian == "kfac": + last_state["forward"].append(Covariance) + last_state["backward"].append(Covariance) + elif hessian == "ekfac": + if Log not in last_state["grad"]: + last_state["grad"].append(Log) + last_state["grad"].append(CorrectedEigval) + elif hessian == "raw": + if Log not in last_state["grad"]: + last_state["grad"].append(Log) + last_state["grad"].append(Covariance) self._logix_state_schedule.append(last_state) def __iter__(self): return self - def __next__(self): + def __next__(self) -> int: logix_state = next(self._schedule_iterator) self._epoch += 1 - if self._epoch == self._lora_epoch: - self.logix.add_lora() - self.logix.setup(logix_state) + + # maybe add lora + if self._epoch == self.get_lora_epoch(self._lora): + self._logix.add_lora() + + # maybe setup save + if self._epoch == self.get_save_epoch(self._save): + self._logix.save(True) + + self._logix.setup(logix_state) + return self._epoch - def __len__(self): + def __len__(self) -> int: return len(self._logix_state_schedule) diff --git a/logix/statistic/log.py b/logix/statistic/log.py index ae745672..cc40ad04 100644 --- 
--- a/logix/statistic/log.py
+++ b/logix/statistic/log.py
@@ -22,7 +22,7 @@ def update(
         """
         Put log into `binfo`
         """
-        module_log = binfo[module_name]
+        module_log = binfo.log[module_name]
         if log_type not in module_log:
             module_log[log_type] = data
         else:
diff --git a/tests/examples/test_compute_influences.py b/tests/examples/test_compute_influences.py
index 0e6c0159..b365deeb 100644
--- a/tests/examples/test_compute_influences.py
+++ b/tests/examples/test_compute_influences.py
@@ -112,6 +112,7 @@ def test_single_checkpoint_influence(self):
         # logix.add_analysis({"influence": InfluenceFunction})
         query_iter = iter(query_loader)
         logix.eval()
+        logix.setup({"grad": ["log"]})
         with logix(data_id=["test"]) as al:
             test_input, test_target = next(query_iter)
             test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)