diff --git a/deepmd_pt/optimizer/KFWrapper.py b/deepmd_pt/optimizer/KFWrapper.py
index 1ce173ab..d204049c 100644
--- a/deepmd_pt/optimizer/KFWrapper.py
+++ b/deepmd_pt/optimizer/KFWrapper.py
@@ -85,6 +85,41 @@ def update_force(
         self.optimizer.step(error)
         return Etot_predict, force_predict
 
+    def update_denoise_coord(
+        self, inputs: dict, clean_coord: torch.Tensor, update_prefactor: float = 1, mask_loss_coord: bool = True, coord_mask: torch.Tensor = None
+    ) -> dict:
+        natoms_sum = inputs['natoms'][0, 0]
+        bs = clean_coord.shape[0]
+        self.optimizer.set_grad_prefactor(natoms_sum * self.atoms_per_group * 3)
+
+        index = self.__sample(self.atoms_selected, self.atoms_per_group, natoms_sum)
+
+        for i in range(index.shape[0]):
+            self.optimizer.zero_grad()
+            model_pred, _, _ = self.model(**inputs, inference_only=True)
+            updated_coord = model_pred['updated_coord']
+            natoms_sum = inputs['natoms'][0, 0]
+            error_tmp = clean_coord[:, index[i]] - updated_coord[:, index[i]]
+            error_tmp = update_prefactor * error_tmp
+            if mask_loss_coord:
+                error_tmp[~coord_mask[:, index[i]]] = 0
+            mask = error_tmp < 0
+            error_tmp[mask] = -1 * error_tmp[mask]
+            error = error_tmp.mean() / natoms_sum
+
+            if self.is_distributed:
+                dist.all_reduce(error)
+                error /= dist.get_world_size()
+
+            tmp_coord_predict = updated_coord[:, index[i]] * update_prefactor
+            tmp_coord_predict[mask] = -update_prefactor * tmp_coord_predict[mask]
+
+            # In order to solve a pytorch bug, reference: https://github.com/pytorch/pytorch/issues/43259
+            (tmp_coord_predict.sum() + updated_coord.sum() * 0).backward()
+            error = error * math.sqrt(bs)
+            self.optimizer.step(error)
+        return model_pred
+
     def __sample(
         self, atoms_selected: int, atoms_per_group: int, natoms: int
     ) -> np.ndarray:
diff --git a/deepmd_pt/optimizer/LKF.py b/deepmd_pt/optimizer/LKF.py
index 42230730..e246ae31 100644
--- a/deepmd_pt/optimizer/LKF.py
+++ b/deepmd_pt/optimizer/LKF.py
@@ -48,7 +48,8 @@ def __init_P(self):
         for param in params:
             param_num = param.data.nelement()
             if param_sum + param_num > block_size:
-                param_nums.append(param_sum)
+                if param_sum > 0:
+                    param_nums.append(param_sum)
                 param_sum = param_num
             else:
                 param_sum += param_num
diff --git a/deepmd_pt/train/training.py b/deepmd_pt/train/training.py
index 27e3e7b4..27ae4308 100644
--- a/deepmd_pt/train/training.py
+++ b/deepmd_pt/train/training.py
@@ -199,20 +199,29 @@ def step(_step_id, task_key="Default"):
                 self.optimizer.step()
                 self.scheduler.step()
             elif self.opt_type == "LKF":
-                KFOptWrapper = KFOptimizerWrapper(
-                    self.wrapper, self.optimizer, 24, 6, dist.is_initialized()
-                )
-                _ = KFOptWrapper.update_energy(input_dict, label_dict["energy"])
-                p_energy, p_force = KFOptWrapper.update_force(
-                    input_dict, label_dict["force"]
-                )
-                # [coord, atype, natoms, mapping, shift, selected, box]
-                model_pred = {"energy": p_energy, "force": p_force}
-                module = self.wrapper.module if dist.is_initialized() else self.wrapper
-                loss, more_loss = module.loss[task_key](
+                if isinstance(self.loss, EnergyStdLoss):
+                    KFOptWrapper = KFOptimizerWrapper(
+                        self.wrapper, self.optimizer, 24, 6, dist.is_initialized()
+                    )
+                    _ = KFOptWrapper.update_energy(input_dict, label_dict["energy"])
+                    p_energy, p_force = KFOptWrapper.update_force(
+                        input_dict, label_dict["force"]
+                    )
+                    # [coord, atype, natoms, mapping, shift, selected, box]
+                    model_pred = {"energy": p_energy, "force": p_force}
+                    module = self.wrapper.module if dist.is_initialized() else self.wrapper
+                    loss, more_loss = module.loss[task_key](
+                        model_pred, label_dict, input_dict["natoms"], learning_rate=cur_lr
+                    )
+                elif isinstance(self.loss, DenoiseLoss):
+                    KFOptWrapper = KFOptimizerWrapper(
+                        self.wrapper, self.optimizer, 24, 6, dist.is_initialized()
+                    )
+                    module = self.wrapper.module if dist.is_initialized() else self.wrapper
+                    model_pred = KFOptWrapper.update_denoise_coord(input_dict, label_dict["clean_coord"], 1, module.loss[task_key].mask_loss_coord, label_dict["coord_mask"])
+                    loss, more_loss = module.loss[task_key](
                     model_pred, label_dict, input_dict["natoms"], learning_rate=cur_lr
                 )
-
             else:
                 raise ValueError("Not supported optimizer type '%s'" % self.opt_type)
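
Note for reviewers: the new `update_denoise_coord` follows the same pattern as the existing `update_force` — it folds the signed coordinate residual into a non-negative scalar error (flipping the corresponding predictions so the Kalman gain still moves them toward the clean coordinates), averages it per atom, all-reduces it across ranks, and scales it by sqrt(batch size) before calling `optimizer.step(error)`. The snippet below is a minimal standalone sketch of that error construction on plain tensors; the shapes, the random data, and the omission of the model, the atom subsampling, the distributed all-reduce, and the LKF optimizer are assumptions for illustration only, not part of this patch.

```python
import torch

# Illustration only: the error folding used in update_denoise_coord,
# reproduced on standalone tensors (no model, no LKF optimizer).
bs, natoms = 4, 8                      # hypothetical batch size / atom count
update_prefactor = 1.0

clean_coord = torch.rand(bs, natoms, 3)                        # ground-truth coordinates
updated_coord = torch.rand(bs, natoms, 3, requires_grad=True)  # model prediction
coord_mask = torch.rand(bs, natoms, 3) > 0.5                   # entries that contribute

# Signed residual, zeroed outside the mask (the mask_loss_coord=True case).
error_tmp = update_prefactor * (clean_coord - updated_coord)
error_tmp = torch.where(coord_mask, error_tmp, torch.zeros_like(error_tmp))

# Fold the sign: the scalar error becomes non-negative, and the matching
# predictions are sign-flipped so backward() still pushes them toward clean_coord.
mask = error_tmp < 0
error = error_tmp.abs().mean() / natoms

tmp_coord_predict = updated_coord * update_prefactor
tmp_coord_predict = torch.where(mask, -tmp_coord_predict, tmp_coord_predict)

# Keep all outputs in the graph (the workaround referenced in the patch),
# then scale the scalar error by sqrt(batch size) as the KF step expects.
(tmp_coord_predict.sum() + updated_coord.sum() * 0).backward()
error = error * (bs ** 0.5)
print(float(error), updated_coord.grad.shape)
```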