Skip to content

Commit

Permalink
Hessian is now logged in {log_root}/{project_name}/hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
YoungseogChung committed Nov 30, 2023
1 parent c8c4aee commit 74960b0
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 27 deletions.
4 changes: 2 additions & 2 deletions analog/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def _set_log_dir(self) -> None:
self._DEFAULTS["global"]["log_dir"] = log_dir
self._DEFAULTS["logging"]["log_dir"] = log_dir
self._DEFAULTS["storage"]["log_dir"] = log_dir
self._DEFAULTS["hessian"]["log_dir"] = log_dir
self._DEFAULTS["hessian"]["log_dir"] = log_dir + "/hessian"
else:
self.data["global"]["log_dir"] = log_dir
self.data["logging"]["log_dir"] = log_dir
self.data["storage"]["log_dir"] = log_dir
self.data["hessian"]["log_dir"] = log_dir
self.data["hessian"]["log_dir"] = log_dir + "/hessian"
3 changes: 0 additions & 3 deletions analog/hessian/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ def __init__(self, config) -> None:
self.hessian_inverse_with_override = False
self.hessian_svd_with_override = False

# Logging
self.file_prefix = "hessian_log_"

self.parse_config()

@abstractmethod
Expand Down
43 changes: 21 additions & 22 deletions analog/hessian/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,61 +206,60 @@ def extract_activations(
assert mode == BACKWARD
return extract_backward_activations(data, module)

def save_state(self):
def save_state(self) -> None:
"""
Save Hessian state to disk.
"""
# TODO: should this be in the constructor or initialize-type function?
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)

# TODO: implement this for all HessianHandlers
if hasattr(self, "hessian_state"):
torch.save(
self.hessian_state,
os.path.join(self.log_dir, self.file_prefix + "hessian.pt"),
)
torch.save(self.hessian_state, os.path.join(self.log_dir, "hessian.pt"))
if hasattr(self, "hessian_eigvec_state"):
torch.save(
self.hessian_eigvec_state,
os.path.join(self.log_dir, self.file_prefix + "hessian_eigvec.pt"),
os.path.join(self.log_dir, "hessian_eigvec.pt"),
)
if hasattr(self, "hessian_eigval_state"):
torch.save(
self.hessian_eigval_state,
os.path.join(self.log_dir, self.file_prefix + "hessian_eigval.pt"),
os.path.join(self.log_dir, "hessian_eigval.pt"),
)
if hasattr(self, "ekfac_eigval_state"):
torch.save(
self.ekfac_eigval_state,
os.path.join(self.log_dir, self.file_prefix + "ekfac_eigval.pt"),
os.path.join(self.log_dir, "ekfac_eigval.pt"),
)
if hasattr(self, "hessian_inverse_state"):
torch.save(
self.hessian_inverse_state,
os.path.join(self.log_dir, self.file_prefix + "hessian_inverse.pt"),
os.path.join(self.log_dir, "hessian_inverse.pt"),
)

def load_state(self, log_dir: str):
def load_state(self, log_dir: str) -> None:
"""
Load Hessian state from disk.
"""
# TODO: implement this for all HessianHandlers
assert os.path.exists(log_dir), "Hessian log directory does not exist!"
log_dir_items = os.listdir(log_dir)
if self.file_prefix + "hessian.pt" in log_dir_items:
self.hessian_state = torch.load(
os.path.join(log_dir, self.file_prefix + "hessian.pt")
)
if self.file_prefix + "hessian_eigvec.pt" in log_dir_items:
if "hessian.pt" in log_dir_items:
self.hessian_state = torch.load(os.path.join(log_dir, "hessian.pt"))
if "hessian_eigvec.pt" in log_dir_items:
self.hessian_eigvec_state = torch.load(
os.path.join(log_dir, self.file_prefix + "hessian_eigvec.pt")
os.path.join(log_dir, "hessian_eigvec.pt")
)
if self.file_prefix + "hessian_eigval.pt" in log_dir_items:
if "hessian_eigval.pt" in log_dir_items:
self.hessian_eigval_state = torch.load(
os.path.join(log_dir, self.file_prefix + "hessian_eigval.pt")
os.path.join(log_dir, "hessian_eigval.pt")
)
if self.file_prefix + "ekfac_eigval.pt" in log_dir_items:
if "ekfac_eigval.pt" in log_dir_items:
self.ekfac_eigval_state = torch.load(
os.path.join(log_dir, self.file_prefix + "ekfac_eigval.pt")
os.path.join(log_dir, "ekfac_eigval.pt")
)
if self.file_prefix + "hessian_inverse.pt" in log_dir_items:
if "hessian_inverse.pt" in log_dir_items:
self.hessian_inverse_state = torch.load(
os.path.join(log_dir, self.file_prefix + "hessian_inverse.pt")
os.path.join(log_dir, "hessian_inverse.pt")
)

0 comments on commit 74960b0

Please sign in to comment.