Skip to content

Commit

Permalink
support KF for denoising
Browse files Browse the repository at this point in the history
Signed-off-by: zjgemi <liuxin_zijian@163.com>
  • Loading branch information
zjgemi committed Apr 27, 2023
1 parent bbb9114 commit 8545a1f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 13 deletions.
35 changes: 35 additions & 0 deletions deepmd_pt/optimizer/KFWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,41 @@ def update_force(
self.optimizer.step(error)
return Etot_predict, force_predict

def update_denoise_coord(
    self,
    inputs: dict,
    clean_coord: torch.Tensor,
    update_prefactor: float = 1,
    mask_loss_coord: bool = True,
    coord_mask: torch.Tensor = None,
) -> dict:
    """Run Kalman-filter updates against the coordinate-denoising target.

    Atoms are sampled into groups by ``self.__sample``; for every group the
    model is re-evaluated, the mean absolute difference between
    ``clean_coord`` and the predicted ``updated_coord`` on that group is
    computed, and one ``self.optimizer.step(error)`` is performed.

    Args:
        inputs: Keyword arguments forwarded to ``self.model``; must contain
            ``'natoms'``, whose ``[0, 0]`` entry is used as the atom count.
        clean_coord: Reference coordinates. Indexed as
            ``clean_coord[:, group]``; the first dimension is used as the
            batch size (presumably shaped (batch, natoms, 3) -- confirm
            against the caller).
        update_prefactor: Scale applied to the error and to the prediction
            that is back-propagated.
        mask_loss_coord: When true, error entries selected by
            ``~coord_mask[:, group]`` are zeroed.
        coord_mask: Boolean mask, indexed as ``coord_mask[:, group]``.
            Required when ``mask_loss_coord`` is true.

    Returns:
        The prediction dict returned by the model in the last group's
        iteration.

    Raises:
        ValueError: If ``mask_loss_coord`` is true but ``coord_mask`` is
            ``None``.
    """
    # Fail early with a clear message instead of a TypeError on ``~None``
    # deep inside the update loop (after a forward pass was already spent).
    if mask_loss_coord and coord_mask is None:
        raise ValueError(
            "coord_mask must be provided when mask_loss_coord is True"
        )

    natoms_sum = inputs['natoms'][0, 0]
    bs = clean_coord.shape[0]
    self.optimizer.set_grad_prefactor(natoms_sum * self.atoms_per_group * 3)

    index = self.__sample(self.atoms_selected, self.atoms_per_group, natoms_sum)
    # Loop-invariant scale applied to the error handed to the optimizer.
    batch_scale = math.sqrt(bs)

    for group in index:
        self.optimizer.zero_grad()
        model_pred, _, _ = self.model(**inputs, inference_only=True)
        updated_coord = model_pred['updated_coord']

        error_tmp = clean_coord[:, group] - updated_coord[:, group]
        error_tmp = update_prefactor * error_tmp
        if mask_loss_coord:
            error_tmp[~coord_mask[:, group]] = 0
        # Flip negative entries so that the error is an absolute value;
        # ``mask`` remembers which entries were flipped so the same sign
        # change can be applied to the prediction before backward().
        mask = error_tmp < 0
        error_tmp[mask] = -1 * error_tmp[mask]
        error = error_tmp.mean() / natoms_sum

        if self.is_distributed:
            # Average the scalar error over all ranks.
            dist.all_reduce(error)
            error /= dist.get_world_size()

        tmp_coord_predict = updated_coord[:, group] * update_prefactor
        # NOTE(review): the prefactor is applied a second time to the
        # flipped entries (same pattern as the trailing lines of
        # update_force); harmless for update_prefactor == 1 -- confirm
        # whether this is intended for other values.
        tmp_coord_predict[mask] = -update_prefactor * tmp_coord_predict[mask]

        # In order to solve a pytorch bug, reference: https://github.com/pytorch/pytorch/issues/43259
        (tmp_coord_predict.sum() + updated_coord.sum() * 0).backward()
        error = error * batch_scale
        self.optimizer.step(error)
    return model_pred

def __sample(
self, atoms_selected: int, atoms_per_group: int, natoms: int
) -> np.ndarray:
Expand Down
3 changes: 2 additions & 1 deletion deepmd_pt/optimizer/LKF.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def __init_P(self):
for param in params:
param_num = param.data.nelement()
if param_sum + param_num > block_size:
param_nums.append(param_sum)
if param_sum > 0:
param_nums.append(param_sum)
param_sum = param_num
else:
param_sum += param_num
Expand Down
33 changes: 21 additions & 12 deletions deepmd_pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,20 +199,29 @@ def step(_step_id, task_key="Default"):
self.optimizer.step()
self.scheduler.step()
elif self.opt_type == "LKF":
KFOptWrapper = KFOptimizerWrapper(
self.wrapper, self.optimizer, 24, 6, dist.is_initialized()
)
_ = KFOptWrapper.update_energy(input_dict, label_dict["energy"])
p_energy, p_force = KFOptWrapper.update_force(
input_dict, label_dict["force"]
)
# [coord, atype, natoms, mapping, shift, selected, box]
model_pred = {"energy": p_energy, "force": p_force}
module = self.wrapper.module if dist.is_initialized() else self.wrapper
loss, more_loss = module.loss[task_key](
if isinstance(self.loss, EnergyStdLoss):
KFOptWrapper = KFOptimizerWrapper(
self.wrapper, self.optimizer, 24, 6, dist.is_initialized()
)
_ = KFOptWrapper.update_energy(input_dict, label_dict["energy"])
p_energy, p_force = KFOptWrapper.update_force(
input_dict, label_dict["force"]
)
# [coord, atype, natoms, mapping, shift, selected, box]
model_pred = {"energy": p_energy, "force": p_force}
module = self.wrapper.module if dist.is_initialized() else self.wrapper
loss, more_loss = module.loss[task_key](
model_pred, label_dict, input_dict["natoms"], learning_rate=cur_lr
)
elif isinstance(self.loss, DenoiseLoss):
KFOptWrapper = KFOptimizerWrapper(
self.wrapper, self.optimizer, 24, 6, dist.is_initialized()
)
module = self.wrapper.module if dist.is_initialized() else self.wrapper
model_pred = KFOptWrapper.update_denoise_coord(input_dict, label_dict["clean_coord"], 1, module.loss[task_key].mask_loss_coord, label_dict["coord_mask"])
loss, more_loss = module.loss[task_key](
model_pred, label_dict, input_dict["natoms"], learning_rate=cur_lr
)

else:
raise ValueError("Not supported optimizer type '%s'" % self.opt_type)

Expand Down

0 comments on commit 8545a1f

Please sign in to comment.