Skip to content

Commit

Permalink
Merge pull request #30 from sangkeun00/hessian_logging
Browse files Browse the repository at this point in the history
Hessian logging
  • Loading branch information
sangkeun00 authored Nov 30, 2023
2 parents 4f9ecd3 + 74960b0 commit 0f44c4f
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 4 deletions.
15 changes: 14 additions & 1 deletion analog/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
self.model = None

# Config
config = Config(config)
config = Config(config_file=config, project_name=project)
self.config = config

# Initialize storage, hessian, and logging handlers from config as well as
Expand Down Expand Up @@ -271,6 +271,7 @@ def build_hessian_handler(self):
Returns:
The initialized Hessian handler.
"""
global_config = self.config.get_global_config()
hessian_config = self.config.get_hessian_config()
hessian_type = hessian_config.get("type", "kfac")
if hessian_type == "kfac":
Expand Down Expand Up @@ -358,6 +359,18 @@ def hessian_svd(self):
"""
return self.hessian_handler.hessian_svd()

def save_hessian(self):
"""
Save Hessian state to disk.
"""
self.hessian_handler.save_state()

def load_hessian(self, log_dir: str):
"""
Load Hessian state from disk.
"""
self.hessian_handler.load_state(log_dir=log_dir)

def finalize(
self,
clear: bool = False,
Expand Down
27 changes: 24 additions & 3 deletions analog/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Dict, Any
import yaml

Expand All @@ -12,15 +13,15 @@ class Config:

# Default values for each configuration
_DEFAULTS = {
"global": {},
"global": {"log_root": "./analog"},
"logging": {},
"storage": {"type": "default", "log_dir": "./analog"},
"storage": {"type": "default"},
"hessian": {"type": "kfac", "damping": 1e-2},
"analysis": {},
"lora": {"init": "pca", "rank": 64},
}

def __init__(self, config_file: str) -> None:
def __init__(self, config_file: str, project_name: str) -> None:
"""
Initialize Config class with given configuration file.
Expand All @@ -35,6 +36,9 @@ def __init__(self, config_file: str) -> None:
)
self.data = {}

self.project_name = project_name
self._set_log_dir()

def get_global_config(self) -> Dict[str, Any]:
"""
Retrieve global configuration.
Expand Down Expand Up @@ -82,3 +86,20 @@ def get_lora_config(self) -> Dict[str, Any]:
:return: Dictionary containing LoRA configurations.
"""
return self.data.get("lora", self._DEFAULTS["lora"])

def _set_log_dir(self) -> None:
"""
Set single logging directory for all components.
"""
log_root = self.get_global_config().get("log_root")
log_dir = os.path.join(log_root, self.project_name)
if len(self.data) == 0:
self._DEFAULTS["global"]["log_dir"] = log_dir
self._DEFAULTS["logging"]["log_dir"] = log_dir
self._DEFAULTS["storage"]["log_dir"] = log_dir
self._DEFAULTS["hessian"]["log_dir"] = log_dir + "/hessian"
else:
self.data["global"]["log_dir"] = log_dir
self.data["logging"]["log_dir"] = log_dir
self.data["storage"]["log_dir"] = log_dir
self.data["hessian"]["log_dir"] = log_dir + "/hessian"
60 changes: 60 additions & 0 deletions analog/hessian/kfac.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Optional

import torch
Expand Down Expand Up @@ -26,6 +27,7 @@ def __init__(self, config: dict) -> None:
self.ekfac_state_unsync = False

def parse_config(self) -> None:
self.log_dir = self.config.get("log_dir")
self.damping = self.config.get("damping", 1e-2)
self.reduce = self.config.get("reduce", False)

Expand Down Expand Up @@ -206,3 +208,61 @@ def extract_activations(
return extract_forward_activations(data, module)
assert mode == BACKWARD
return extract_backward_activations(data, module)

def save_state(self) -> None:
"""
Save Hessian state to disk.
"""
# TODO: should this be in the constructor or initialize-type function?
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)

# TODO: implement this for all HessianHandlers
if hasattr(self, "hessian_state"):
torch.save(self.hessian_state, os.path.join(self.log_dir, "hessian.pt"))
if hasattr(self, "hessian_eigvec_state"):
torch.save(
self.hessian_eigvec_state,
os.path.join(self.log_dir, "hessian_eigvec.pt"),
)
if hasattr(self, "hessian_eigval_state"):
torch.save(
self.hessian_eigval_state,
os.path.join(self.log_dir, "hessian_eigval.pt"),
)
if hasattr(self, "ekfac_eigval_state"):
torch.save(
self.ekfac_eigval_state,
os.path.join(self.log_dir, "ekfac_eigval.pt"),
)
if hasattr(self, "hessian_inverse_state"):
torch.save(
self.hessian_inverse_state,
os.path.join(self.log_dir, "hessian_inverse.pt"),
)

def load_state(self, log_dir: str) -> None:
"""
Load Hessian state from disk.
"""
# TODO: implement this for all HessianHandlers
assert os.path.exists(log_dir), "Hessian log directory does not exist!"
log_dir_items = os.listdir(log_dir)
if "hessian.pt" in log_dir_items:
self.hessian_state = torch.load(os.path.join(log_dir, "hessian.pt"))
if "hessian_eigvec.pt" in log_dir_items:
self.hessian_eigvec_state = torch.load(
os.path.join(log_dir, "hessian_eigvec.pt")
)
if "hessian_eigval.pt" in log_dir_items:
self.hessian_eigval_state = torch.load(
os.path.join(log_dir, "hessian_eigval.pt")
)
if "ekfac_eigval.pt" in log_dir_items:
self.ekfac_eigval_state = torch.load(
os.path.join(log_dir, "ekfac_eigval.pt")
)
if "hessian_inverse.pt" in log_dir_items:
self.hessian_inverse_state = torch.load(
os.path.join(log_dir, "hessian_inverse.pt")
)
1 change: 1 addition & 0 deletions analog/hessian/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class RawHessianHandler(HessianHandlerBase):
"""

def parse_config(self) -> None:
self.log_dir = self.config.get("log_dir")
self.damping = self.config.get("damping", 1e-2)
self.reduce = self.config.get("reduce", True)

Expand Down

0 comments on commit 0f44c4f

Please sign in to comment.