Skip to content

Commit

Permalink
Merge branch 'main' into flushOnlyWhenNecessary
Browse files Browse the repository at this point in the history
  • Loading branch information
hwijeen committed Nov 24, 2023
2 parents 1a3f941 + bd27357 commit b5bdaf6
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 43 deletions.
1 change: 1 addition & 0 deletions analog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from analog.analog import AnaLog
from analog.scheduler import AnaLogScheduler


version = "0.1.0"
72 changes: 38 additions & 34 deletions analog/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
"""
self.project = project

self.model = None

# Config
config = Config(config)
self.config = config
Expand All @@ -51,6 +53,11 @@ def __init__(
self.test = False
self.mask = None

self.log_default = []
self.hessian_default = False
self.save_default = False
self.test_default = False

self.type_filter = None
self.name_filter = None

Expand All @@ -70,6 +77,7 @@ def watch(
name_filter (list, optional): List of keyword names for modules to be watched.
lora (bool, optional): Whether to use LoRA to watch the model.
"""
self.model = model
self.type_filter = type_filter or self.type_filter
self.name_filter = name_filter or self.name_filter

Expand Down Expand Up @@ -106,7 +114,7 @@ def watch_activation(self, tensor_dict: Dict[str, torch.Tensor]) -> None:

def add_lora(
self,
model: nn.Module,
model: Optional[nn.Module] = None,
parameter_sharing: bool = False,
parameter_sharing_groups: List[str] = None,
watch: bool = True,
Expand All @@ -122,6 +130,9 @@ def add_lora(
watch (bool, optional): Whether to watch the model or not.
clear (bool, optional): Whether to clear the internal states or not.
"""
if model is None:
model = self.model

hessian_state = self.hessian_handler.get_hessian_state()
self.lora_handler.add_lora(
model=model,
Expand Down Expand Up @@ -180,12 +191,11 @@ def remove_analysis(self, analysis_name: str) -> None:
def __call__(
self,
data_id: Optional[Iterable[Any]] = None,
log: Iterable[str] = [FORWARD, BACKWARD],
hessian: bool = True,
save: bool = False,
log: Optional[Iterable[str]] = None,
hessian: Optional[bool] = None,
save: Optional[bool] = None,
test: bool = False,
mask: Optional[torch.Tensor] = None,
strategy: Optional[str] = None,
):
"""
Args:
Expand All @@ -198,17 +208,15 @@ def __call__(
Returns:
self: Returns the instance of the AnaLog object.
"""
if strategy is None:
self.data_id = data_id
self.log = log
self.hessian = hessian if not test else False
self.save = save if not test else False
self.test = test
self.mask = mask
else:
self.parse_strategy(strategy)
self.data_id = data_id
self.mask = mask

self.log = log or self.log_default
self.hessian = hessian or self.hessian_default
self.save = save or self.save_default
self.test = test or self.test_default

self.sanity_check(self.data_id, self.log, self.test)
self.sanity_check(self.data_id, self.log, self.hessian, self.save, self.test)

return self

Expand Down Expand Up @@ -236,7 +244,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
This method is essential for ensuring that there are no lingering hooks that could
interfere with further operations on the model or with future logging sessions.
"""
self.hessian_handler.on_exit(self.logging_handler.current_log)
self.hessian_handler.on_exit(self.logging_handler.current_log, self.hessian)
if self.save:
self.storage_handler.flush()
self.logging_handler.clear()
Expand Down Expand Up @@ -364,30 +372,21 @@ def finalize(
self.hessian_handler.clear()
self.storage_handler.clear()

def parse_strategy(self, strategy: str) -> None:
"""
Parses the strategy string to set the internal states.
Args:
strategy (str): The strategy string.
"""
strategy = strategy.lower()
if strategy == "train":
self.log = [FORWARD, BACKWARD]
self.hessian = True
self.save = False
self.test = False
else:
raise ValueError(f"Unknown strategy: {strategy}")

def sanity_check(
self, data_id: Iterable[Any], log: Iterable[str], test: bool
self,
data_id: Iterable[Any],
log: Iterable[str],
hessian: bool,
save: bool,
test: bool,
) -> None:
"""
Performs a sanity check on the provided parameters.
"""
if len(log) > 0 and len(set(log) - LOG_TYPES) > 0:
raise ValueError("Invalid value for 'log'.")
if test and (hessian or save):
raise ValueError("Cannot compute Hessian or save logs during testing.")
if not test and data_id is None:
raise ValueError("Must provide data_id for logging.")
if GRAD in log and len(log) > 1:
Expand All @@ -403,11 +402,16 @@ def ekfac(self, on: bool = True) -> None:
else:
self.hessian_handler.ekfac = False

def set_default_state(self, log: List[str], hessian: bool, save: bool):
self.log_default = log
self.hessian_default = hessian
self.save_default = save

def reset(self) -> None:
"""
Reset the internal states.
"""
self.log = None
self.log = []
self.hessian = False
self.save = False
self.test = False
Expand Down
15 changes: 8 additions & 7 deletions analog/hessian/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@ def parse_config(self) -> None:
self.reduce = self.config.get("reduce", False)

@torch.no_grad()
def on_exit(self, current_log=None) -> None:
if self.reduce:
raise NotImplementedError

if self.ekfac:
for module_name, module_grad in current_log.items():
self.update_ekfac(module_name, module_grad)
def on_exit(self, current_log=None, update_hessian=True) -> None:
if update_hessian:
if self.reduce:
raise NotImplementedError

if self.ekfac:
for module_name, module_grad in current_log.items():
self.update_ekfac(module_name, module_grad)

@torch.no_grad()
def update_hessian(
Expand Down
90 changes: 90 additions & 0 deletions analog/scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from analog import AnaLog
from analog.utils import get_logger


class AnaLogScheduler:
def __init__(
self,
analog: AnaLog,
ekfac: bool = False,
lora: bool = False,
sample: bool = False,
):
self.analog = analog

self._epoch = -1
self.analog_state_schedule = []
self.execution_schedule = {"ekfac": -1, "lora": -1}

self.generate_schedule(ekfac, lora, sample)

def generate_schedule(
self, ekfac: bool = False, lora: bool = False, sample: bool = False
):
if lora:
self.execution_schedule["lora"] = 1
if ekfac:
self.execution_schedule["ekfac"] = 1 + int(lora)

# (log, hessian, save) for analog
if ekfac and lora and sample:
self.analog_state_schedule = [
([], True, False),
([], True, False),
(["grad"], True, False),
(["grad"], False, True),
]
elif ekfac and lora and not sample:
self.analog_state_schedule = [
([], True, False),
([], True, False),
(["grad"], True, True),
]
elif ekfac and not lora and sample:
self.analog_state_schedule = [
([], True, False),
(["grad"], True, False),
(["grad"], False, True),
]
elif ekfac and not lora and not sample:
self.analog_state_schedule = [
([], True, False),
(["grad"], True, True),
]
elif not ekfac and lora and sample:
self.analog_state_schedule = [
([], True, False),
([], True, False),
([grad], False, True),
]
elif not ekfac and lora and not sample:
self.analog_state_schedule = [
([], True, False),
(["grad"], True, True),
]
elif not ekfac and not lora and sample:
self.analog_state_schedule = [
([], True, False),
(["grad"], False, True),
]
elif not ekfac and not lora and not sample:
self.analog_state_schedule = [
(["grad"], True, True),
]

def __iter__(self):
return self

def __next__(self):
self._epoch += 1
if self._epoch < len(self.analog_state_schedule):
self.analog.set_default_state(*self.analog_state_schedule[self._epoch])
if self._epoch == self.execution_schedule["ekfac"]:
self.analog.ekfac()
if self._epoch == self.execution_schedule["lora"]:
self.analog.add_lora()
return self._epoch
raise StopIteration

def __len__(self):
return len(self.analog_state_schedule)
2 changes: 1 addition & 1 deletion examples/mnist_influence/compute_influences.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
id_gen = DataIDGenerator()
for inputs, targets in train_loader:
data_id = id_gen(inputs)
with analog(data_id=data_id, log=["grad"], save=True):
with analog(data_id=data_id, log=["grad"], hessian=True, save=True):
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
model.zero_grad()
outs = model(inputs)
Expand Down
2 changes: 1 addition & 1 deletion tests/examples/test_compute_influences.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_single_checkpoint_influence(self):
id_gen = DataIDGenerator()
for inputs, targets in train_loader:
data_id = id_gen(inputs)
with analog(data_id=data_id, log=["grad"], save=True):
with analog(data_id=data_id, log=["grad"], hessian=True, save=True):
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
model.zero_grad()
outs = model(inputs)
Expand Down

0 comments on commit b5bdaf6

Please sign in to comment.