Commit
Merge branch 'main' into fix-2
eatpk authored Nov 25, 2023
2 parents 644aa11 + bfddd80 commit cd66fda
Showing 19 changed files with 407 additions and 182 deletions.
1 change: 1 addition & 0 deletions analog/__init__.py
@@ -1,4 +1,5 @@
from analog.analog import AnaLog
from analog.scheduler import AnaLogScheduler


version = "0.1.0"
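Editorial note (not part of the diff): with this addition, AnaLogScheduler is exported from the package root alongside AnaLog, so the following import (a minimal illustration) should work:

from analog import AnaLog, AnaLogScheduler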
74 changes: 39 additions & 35 deletions analog/analog.py
@@ -6,7 +6,7 @@
from analog.config import Config
from analog.constants import FORWARD, BACKWARD, GRAD, LOG_TYPES
from analog.logging import LoggingHandler
from analog.storage import DefaultStorageHandler, MongoDBStorageHandler
from analog.storage import DefaultStorageHandler
from analog.hessian import RawHessianHandler, KFACHessianHandler
from analog.analysis import AnalysisBase
from analog.lora import LoRAHandler
@@ -30,6 +30,8 @@ def __init__(
"""
self.project = project

self.model = None

# Config
config = Config(config)
self.config = config
@@ -51,6 +53,11 @@ def __init__(
self.test = False
self.mask = None

self.log_default = []
self.hessian_default = False
self.save_default = False
self.test_default = False

self.type_filter = None
self.name_filter = None

@@ -70,6 +77,7 @@ def watch(
name_filter (list, optional): List of keyword names for modules to be watched.
lora (bool, optional): Whether to use LoRA to watch the model.
"""
self.model = model
self.type_filter = type_filter or self.type_filter
self.name_filter = name_filter or self.name_filter

@@ -106,7 +114,7 @@ def watch_activation(self, tensor_dict: Dict[str, torch.Tensor]) -> None:

def add_lora(
self,
model: nn.Module,
model: Optional[nn.Module] = None,
parameter_sharing: bool = False,
parameter_sharing_groups: List[str] = None,
watch: bool = True,
@@ -122,6 +130,9 @@ def add_lora(
watch (bool, optional): Whether to watch the model or not.
clear (bool, optional): Whether to clear the internal states or not.
"""
if model is None:
model = self.model

hessian_state = self.hessian_handler.get_hessian_state()
self.lora_handler.add_lora(
model=model,
@@ -180,12 +191,11 @@ def remove_analysis(self, analysis_name: str) -> None:
def __call__(
self,
data_id: Optional[Iterable[Any]] = None,
log: Iterable[str] = [FORWARD, BACKWARD],
hessian: bool = True,
save: bool = False,
log: Optional[Iterable[str]] = None,
hessian: Optional[bool] = None,
save: Optional[bool] = None,
test: bool = False,
mask: Optional[torch.Tensor] = None,
strategy: Optional[str] = None,
):
"""
Args:
@@ -198,17 +208,15 @@
Returns:
self: Returns the instance of the AnaLog object.
"""
if strategy is None:
self.data_id = data_id
self.log = log
self.hessian = hessian if not test else False
self.save = save if not test else False
self.test = test
self.mask = mask
else:
self.parse_strategy(strategy)
self.data_id = data_id
self.mask = mask

self.log = log or self.log_default
self.hessian = hessian or self.hessian_default
self.save = save or self.save_default
self.test = test or self.test_default

self.sanity_check(self.data_id, self.log, self.test)
self.sanity_check(self.data_id, self.log, self.hessian, self.save, self.test)

return self

@@ -236,7 +244,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
This method is essential for ensuring that there are no lingering hooks that could
interfere with further operations on the model or with future logging sessions.
"""
self.hessian_handler.on_exit(self.logging_handler.current_log)
self.hessian_handler.on_exit(self.logging_handler.current_log, self.hessian)
self.storage_handler.flush()
self.logging_handler.clear()

@@ -363,30 +371,21 @@ def finalize(
self.hessian_handler.clear()
self.storage_handler.clear()

def parse_strategy(self, strategy: str) -> None:
"""
Parses the strategy string to set the internal states.
Args:
strategy (str): The strategy string.
"""
strategy = strategy.lower()
if strategy == "train":
self.log = [FORWARD, BACKWARD]
self.hessian = True
self.save = False
self.test = False
else:
raise ValueError(f"Unknown strategy: {strategy}")

def sanity_check(
self, data_id: Iterable[Any], log: Iterable[str], test: bool
self,
data_id: Iterable[Any],
log: Iterable[str],
hessian: bool,
save: bool,
test: bool,
) -> None:
"""
Performs a sanity check on the provided parameters.
"""
if len(log) > 0 and len(set(log) - LOG_TYPES) > 0:
raise ValueError("Invalid value for 'log'.")
if test and (hessian or save):
raise ValueError("Cannot compute Hessian or save logs during testing.")
if not test and data_id is None:
raise ValueError("Must provide data_id for logging.")
if GRAD in log and len(log) > 1:
@@ -402,11 +401,16 @@ def ekfac(self, on: bool = True) -> None:
else:
self.hessian_handler.ekfac = False

def set_default_state(self, log: List[str], hessian: bool, save: bool):
self.log_default = log
self.hessian_default = hessian
self.save_default = save

def reset(self) -> None:
"""
Reset the internal states.
"""
self.log = None
self.log = []
self.hessian = False
self.save = False
self.test = False
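Editorial note (not part of the diff): taken together, these changes drop the strategy argument from __call__ in favor of per-instance defaults set through set_default_state, which __call__ falls back to when log, hessian, or save is omitted, and the instance now remembers the watched model so add_lora can be called without passing one. A minimal usage sketch, assuming AnaLog is used as a context manager (only its __exit__ appears above) and that the constructor takes a project name as in this excerpt; the toy model, project name, and data ids are placeholders:

import torch
import torch.nn as nn

from analog import AnaLog
from analog.constants import FORWARD, BACKWARD

model = nn.Linear(4, 2)                      # toy stand-in for a real model
run = AnaLog(project="example")              # "example" is a placeholder project name
run.watch(model)                             # the instance keeps a reference to the model
run.set_default_state(log=[FORWARD, BACKWARD], hessian=True, save=False)

with run(data_id=["sample-0"]):              # log/hessian/save fall back to the defaults above
    loss = model(torch.randn(1, 4)).sum()
    loss.backward()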
20 changes: 10 additions & 10 deletions analog/hessian/kfac.py
@@ -27,13 +27,14 @@ def parse_config(self) -> None:
self.reduce = self.config.get("reduce", False)

@torch.no_grad()
def on_exit(self, current_log=None) -> None:
if self.reduce:
raise NotImplementedError
def on_exit(self, current_log=None, update_hessian=True) -> None:
if update_hessian:
if self.reduce:
raise NotImplementedError

if self.ekfac:
for module_name, module_grad in current_log.items():
self.update_ekfac(module_name, module_grad)
if self.ekfac:
for module_name, module_grad in current_log.items():
self.update_ekfac(module_name, module_grad)

@torch.no_grad()
def update_hessian(
@@ -47,7 +48,7 @@ def update_hessian(
if self.reduce or self.ekfac:
return
# extract activations
activation = self.extract_activations(module, mode, data, mask)
activation = self.extract_activations(module, mode, data)

# compute covariance
covariance = torch.matmul(torch.t(activation), activation).cpu().detach()
@@ -185,9 +186,8 @@ def extract_activations(
module: nn.Module,
mode: str,
data: torch.Tensor,
mask: Optional[torch.Tensor],
) -> torch.Tensor:
if mode == FORWARD:
return extract_forward_activations(data, module, mask)
return extract_forward_activations(data, module)
assert mode == BACKWARD
return extract_backward_activations(data, module, mask)
return extract_backward_activations(data, module)
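Editorial note (not part of the diff): the quantity update_hessian builds above is the uncentered second moment of the reshaped activations, torch.matmul(torch.t(activation), activation). A standalone sketch with made-up shapes, purely for illustration:

import torch

acts = torch.randn(32, 16)   # (flattened tokens, features) after reshaping
cov = acts.t() @ acts        # 16 x 16 uncentered activation covariance
# The handler presumably accumulates such per-batch covariances in its state;
# that accumulation is not shown in this excerpt.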
28 changes: 0 additions & 28 deletions analog/hessian/utils.py
@@ -56,7 +56,6 @@ def extract_patches(
def extract_forward_activations(
activations: torch.Tensor,
module: nn.Module,
activations_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Extract and reshape activations into valid shapes for covariance computations.
@@ -65,43 +64,17 @@
Raw pre-activations supplied to the module.
module (nn.Module):
The module where the activations are applied.
activations_mask (torch.Tensor, optional):
If padding with dummy inputs is applied to the batch, provide the same mask.
"""
if isinstance(module, nn.Linear):
if (
activations_mask is not None
and activations_mask.shape[:-1] == activations.shape[:-1]
):
activations *= activations_mask
reshaped_activations = activations.reshape(-1, activations.shape[-1])

# ! Ignore bias for now
# if module.bias is not None:
# shape = list(reshaped_activations.shape[:-1]) + [1]
# append_term = reshaped_activations.new_ones(shape)
# if (
# activations_mask is not None
# and activations_mask.shape[:-1] == activations.shape[:-1]
# ):
# append_term *= activations_mask.view(-1, 1)
# reshaped_activations = torch.cat(
# [reshaped_activations, append_term], dim=-1
# )
elif isinstance(module, nn.Conv2d):
del activations_mask
reshaped_activations = extract_patches(
activations, module.kernel_size, module.stride, module.padding
)
reshaped_activations = reshaped_activations.view(
-1, reshaped_activations.size(-1)
)
# ! Ignore bias for now
# if module.bias is not None:
# shape = list(reshaped_activations.shape[:-1]) + [1]
# reshaped_activations = torch.cat(
# [reshaped_activations, reshaped_activations.new_ones(shape)], dim=-1
# )
else:
raise InvalidModuleError()
return reshaped_activations
@@ -110,7 +83,6 @@ def extract_forward_activations(
def extract_backward_activations(
gradients: torch.Tensor,
module: nn.Module,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Extract and reshape gradients into valid shapes for covariance computations.
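Editorial note (not part of the diff): with the mask arguments removed, these helpers only reshape. For nn.Linear the reshaping shown above amounts to flattening all leading dimensions, for example:

import torch

activations = torch.randn(8, 10, 64)                        # (batch, seq, features), made-up shapes
reshaped = activations.reshape(-1, activations.shape[-1])   # (batch * seq, features), as in the nn.Linear branch

The nn.Conv2d branch additionally unfolds the input into patches via extract_patches before the same flattening.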
4 changes: 2 additions & 2 deletions analog/logging.py
@@ -80,7 +80,7 @@ def _forward_hook_fn(

if self.hessian and self.hessian_type == "kfac":
self.hessian_handler.update_hessian(
module, module_name, FORWARD, activations
module, module_name, FORWARD, activations, self.mask
)

if FORWARD in self.log:
@@ -111,7 +111,7 @@ def _backward_hook_fn(

if self.hessian and self.hessian_type == "kfac":
self.hessian_handler.update_hessian(
module, module_name, BACKWARD, grad_outputs[0]
module, module_name, BACKWARD, grad_outputs[0], self.mask
)

if BACKWARD in self.log:
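Editorial note (not part of the diff): both hooks now hand self.mask to the Hessian handler instead of threading it through the reshape helpers. Assuming the handler applies the mask the way the removed utils.py code did (multiplying activations by the mask before flattening), which is an assumption rather than something shown in this diff, the effect on the covariance is roughly:

import torch

acts = torch.randn(2, 5, 8)                                  # (batch, seq, features), made-up shapes
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 0.]]).unsqueeze(-1)    # 1 = real token, 0 = padding
flat = (acts * mask).reshape(-1, acts.shape[-1])             # padded positions are zeroed out
cov = flat.t() @ flat                                        # zeroed rows add nothing to the covariance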
14 changes: 13 additions & 1 deletion analog/lora/lora.py
@@ -18,6 +18,17 @@ def find_parameter_sharing_group(
return found_groups[0]


def _get_submodules(model, key):
"""
Helper function to replace a module with transformers model
https://github.com/huggingface/peft/blob/c0dd27bc974e4a62c6072142146887b75bb2de6c/src/peft/utils/other.py#L251
"""
parent = model.get_submodule(".".join(key.split(".")[:-1]))
target_name = key.split(".")[-1]
target = model.get_submodule(key)
return parent, target, target_name


class LoRAHandler:
"""
Transforms a model into a Lora model.
@@ -94,4 +105,5 @@ def add_lora(
lora_module.init_weight(self.init_strategy, hessian_state[name])
lora_module.to(device)

setattr(model, name, lora_module)
parent, target, target_name = _get_submodules(model, name)
setattr(parent, target_name, lora_module)
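Editorial note (not part of the diff): the old setattr(model, name, lora_module) only replaces top-level attributes, while the module names handled here are dotted paths (hence the _get_submodules helper). Resolving the parent first swaps the nested child in place. A minimal sketch with a made-up nested model:

import torch.nn as nn

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(4, 4))

    def forward(self, x):
        return self.block(x)

model = Outer()
name = "block.0"                                                # dotted path to the inner Linear

parent = model.get_submodule(".".join(name.split(".")[:-1]))    # resolves model.block
target_name = name.split(".")[-1]                               # "0"
setattr(parent, target_name, nn.Linear(4, 8))                   # replaces the nested child in place
# A flat setattr(model, "block.0", ...) would register a new entry under the literal
# name "block.0" on the top-level module and leave model.block[0] untouched.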