diff --git a/analog/__init__.py b/analog/__init__.py index d65c014d..f02cc866 100644 --- a/analog/__init__.py +++ b/analog/__init__.py @@ -1,4 +1,5 @@ from analog.analog import AnaLog +from analog.scheduler import AnaLogScheduler version = "0.1.0" diff --git a/analog/analog.py b/analog/analog.py index 4970e1c2..aaadb9ae 100644 --- a/analog/analog.py +++ b/analog/analog.py @@ -30,6 +30,8 @@ def __init__( """ self.project = project + self.model = None + # Config config = Config(config) self.config = config @@ -51,6 +53,11 @@ def __init__( self.test = False self.mask = None + self.log_default = [] + self.hessian_default = False + self.save_default = False + self.test_default = False + self.type_filter = None self.name_filter = None @@ -70,6 +77,7 @@ def watch( name_filter (list, optional): List of keyword names for modules to be watched. lora (bool, optional): Whether to use LoRA to watch the model. """ + self.model = model self.type_filter = type_filter or self.type_filter self.name_filter = name_filter or self.name_filter @@ -106,7 +114,7 @@ def watch_activation(self, tensor_dict: Dict[str, torch.Tensor]) -> None: def add_lora( self, - model: nn.Module, + model: Optional[nn.Module] = None, parameter_sharing: bool = False, parameter_sharing_groups: List[str] = None, watch: bool = True, @@ -122,6 +130,9 @@ def add_lora( watch (bool, optional): Whether to watch the model or not. clear (bool, optional): Whether to clear the internal states or not. """ + if model is None: + model = self.model + hessian_state = self.hessian_handler.get_hessian_state() self.lora_handler.add_lora( model=model, @@ -180,12 +191,11 @@ def remove_analysis(self, analysis_name: str) -> None: def __call__( self, data_id: Optional[Iterable[Any]] = None, - log: Iterable[str] = [FORWARD, BACKWARD], - hessian: bool = True, - save: bool = False, + log: Optional[Iterable[str]] = None, + hessian: Optional[bool] = None, + save: Optional[bool] = None, test: bool = False, mask: Optional[torch.Tensor] = None, - strategy: Optional[str] = None, ): """ Args: @@ -198,17 +208,15 @@ def __call__( Returns: self: Returns the instance of the AnaLog object. """ - if strategy is None: - self.data_id = data_id - self.log = log - self.hessian = hessian if not test else False - self.save = save if not test else False - self.test = test - self.mask = mask - else: - self.parse_strategy(strategy) + self.data_id = data_id + self.mask = mask + + self.log = log or self.log_default + self.hessian = hessian or self.hessian_default + self.save = save or self.save_default + self.test = test or self.test_default - self.sanity_check(self.data_id, self.log, self.test) + self.sanity_check(self.data_id, self.log, self.hessian, self.save, self.test) return self @@ -236,7 +244,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: This method is essential for ensuring that there are no lingering hooks that could interfere with further operations on the model or with future logging sessions. """ - self.hessian_handler.on_exit(self.logging_handler.current_log) + self.hessian_handler.on_exit(self.logging_handler.current_log, self.hessian) if self.save: self.storage_handler.flush() self.logging_handler.clear() @@ -364,30 +372,21 @@ def finalize( self.hessian_handler.clear() self.storage_handler.clear() - def parse_strategy(self, strategy: str) -> None: - """ - Parses the strategy string to set the internal states. - - Args: - strategy (str): The strategy string. - """ - strategy = strategy.lower() - if strategy == "train": - self.log = [FORWARD, BACKWARD] - self.hessian = True - self.save = False - self.test = False - else: - raise ValueError(f"Unknown strategy: {strategy}") - def sanity_check( - self, data_id: Iterable[Any], log: Iterable[str], test: bool + self, + data_id: Iterable[Any], + log: Iterable[str], + hessian: bool, + save: bool, + test: bool, ) -> None: """ Performs a sanity check on the provided parameters. """ if len(log) > 0 and len(set(log) - LOG_TYPES) > 0: raise ValueError("Invalid value for 'log'.") + if test and (hessian or save): + raise ValueError("Cannot compute Hessian or save logs during testing.") if not test and data_id is None: raise ValueError("Must provide data_id for logging.") if GRAD in log and len(log) > 1: @@ -403,11 +402,16 @@ def ekfac(self, on: bool = True) -> None: else: self.hessian_handler.ekfac = False + def set_default_state(self, log: List[str], hessian: bool, save: bool): + self.log_default = log + self.hessian_default = hessian + self.save_default = save + def reset(self) -> None: """ Reset the internal states. """ - self.log = None + self.log = [] self.hessian = False self.save = False self.test = False diff --git a/analog/hessian/kfac.py b/analog/hessian/kfac.py index abab2fb7..972d0d3d 100644 --- a/analog/hessian/kfac.py +++ b/analog/hessian/kfac.py @@ -27,13 +27,14 @@ def parse_config(self) -> None: self.reduce = self.config.get("reduce", False) @torch.no_grad() - def on_exit(self, current_log=None) -> None: - if self.reduce: - raise NotImplementedError - - if self.ekfac: - for module_name, module_grad in current_log.items(): - self.update_ekfac(module_name, module_grad) + def on_exit(self, current_log=None, update_hessian=True) -> None: + if update_hessian: + if self.reduce: + raise NotImplementedError + + if self.ekfac: + for module_name, module_grad in current_log.items(): + self.update_ekfac(module_name, module_grad) @torch.no_grad() def update_hessian( diff --git a/analog/scheduler.py b/analog/scheduler.py new file mode 100644 index 00000000..e95e181e --- /dev/null +++ b/analog/scheduler.py @@ -0,0 +1,90 @@ +from analog import AnaLog +from analog.utils import get_logger + + +class AnaLogScheduler: + def __init__( + self, + analog: AnaLog, + ekfac: bool = False, + lora: bool = False, + sample: bool = False, + ): + self.analog = analog + + self._epoch = -1 + self.analog_state_schedule = [] + self.execution_schedule = {"ekfac": -1, "lora": -1} + + self.generate_schedule(ekfac, lora, sample) + + def generate_schedule( + self, ekfac: bool = False, lora: bool = False, sample: bool = False + ): + if lora: + self.execution_schedule["lora"] = 1 + if ekfac: + self.execution_schedule["ekfac"] = 1 + int(lora) + + # (log, hessian, save) for analog + if ekfac and lora and sample: + self.analog_state_schedule = [ + ([], True, False), + ([], True, False), + (["grad"], True, False), + (["grad"], False, True), + ] + elif ekfac and lora and not sample: + self.analog_state_schedule = [ + ([], True, False), + ([], True, False), + (["grad"], True, True), + ] + elif ekfac and not lora and sample: + self.analog_state_schedule = [ + ([], True, False), + (["grad"], True, False), + (["grad"], False, True), + ] + elif ekfac and not lora and not sample: + self.analog_state_schedule = [ + ([], True, False), + (["grad"], True, True), + ] + elif not ekfac and lora and sample: + self.analog_state_schedule = [ + ([], True, False), + ([], True, False), + ([grad], False, True), + ] + elif not ekfac and lora and not sample: + self.analog_state_schedule = [ + ([], True, False), + (["grad"], True, True), + ] + elif not ekfac and not lora and sample: + self.analog_state_schedule = [ + ([], True, False), + (["grad"], False, True), + ] + elif not ekfac and not lora and not sample: + self.analog_state_schedule = [ + (["grad"], True, True), + ] + + def __iter__(self): + return self + + def __next__(self): + self._epoch += 1 + if self._epoch < len(self.analog_state_schedule): + self.analog.set_default_state(*self.analog_state_schedule[self._epoch]) + if self._epoch == self.execution_schedule["ekfac"]: + self.analog.ekfac() + if self._epoch == self.execution_schedule["lora"]: + self.analog.add_lora() + return self._epoch + raise StopIteration + + def __len__(self): + return len(self.analog_state_schedule) diff --git a/examples/mnist_influence/compute_influences.py b/examples/mnist_influence/compute_influences.py index 2f1afa98..9dbc25f7 100644 --- a/examples/mnist_influence/compute_influences.py +++ b/examples/mnist_influence/compute_influences.py @@ -41,7 +41,7 @@ id_gen = DataIDGenerator() for inputs, targets in train_loader: data_id = id_gen(inputs) - with analog(data_id=data_id, log=["grad"], save=True): + with analog(data_id=data_id, log=["grad"], hessian=True, save=True): inputs, targets = inputs.to(DEVICE), targets.to(DEVICE) model.zero_grad() outs = model(inputs) diff --git a/tests/examples/test_compute_influences.py b/tests/examples/test_compute_influences.py index c947c495..1028086d 100644 --- a/tests/examples/test_compute_influences.py +++ b/tests/examples/test_compute_influences.py @@ -95,7 +95,7 @@ def test_single_checkpoint_influence(self): id_gen = DataIDGenerator() for inputs, targets in train_loader: data_id = id_gen(inputs) - with analog(data_id=data_id, log=["grad"], save=True): + with analog(data_id=data_id, log=["grad"], hessian=True, save=True): inputs, targets = inputs.to(DEVICE), targets.to(DEVICE) model.zero_grad() outs = model(inputs)