Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed Nov 27, 2023
1 parent d9e8695 commit 7440dda
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
8 changes: 6 additions & 2 deletions analog/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,6 @@ def build_storage_handler(self) -> None:
storage_type = storage_config.get("type", "default")
if storage_type == "default":
return DefaultStorageHandler(storage_config)
elif storage_type == "mongodb":
return MongoDBStorageHandler(storage_config)
else:
raise ValueError(f"Unknown storage type: {storage_type}")

Expand Down Expand Up @@ -342,6 +340,12 @@ def get_hessian_state(
return hessian_state.copy()
return hessian_state

def get_hessian_svd_state(self) -> Dict[str, Dict[str, torch.Tensor]]:
"""
Returns the SVD of the Hessian from the Hessian handler.
"""
return self.hessian_handler.get_hessian_svd_state()

def hessian_inverse(self):
"""
Compute the inverse of the Hessian.
Expand Down
6 changes: 3 additions & 3 deletions analog/hessian/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def update_hessian(
)
else:
self.hessian_state[module_name][mode].addmm_(activation.t(), activation)
self.sample_counter[module_name][mode] += self.get_sample_size(data, mask)
self.sample_counter[module_name][mode] += len(data)

@torch.no_grad()
def update_ekfac(
Expand Down Expand Up @@ -173,12 +173,12 @@ def synchronize(self) -> None:

if self.ekfac:
for _, ekfac_eigval in self.ekfac_eigval_state.items():
ekfac_eigval.div_(world_size)
ekfac_eigval.div_(get_world_size())
dist.all_reduce(ekfac_eigval, op=dist.ReduceOp.SUM)
else:
for _, module_state in self.hessian_state.items():
for _, covariance in module_state.items():
covariance.div_(world_size)
covariance.div_(get_world_size())
dist.all_reduce(covariance, op=dist.ReduceOp.SUM)

def clear(self) -> None:
Expand Down

0 comments on commit 7440dda

Please sign in to comment.