Skip to content

Commit

Permalink
Chore: refactor InvarFitting (#3266)
Browse files Browse the repository at this point in the history
Signed-off-by: Anyang Peng <137014849+anyangml@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Feb 16, 2024
1 parent 43f17da commit 15f8d25
Show file tree
Hide file tree
Showing 6 changed files with 440 additions and 277 deletions.
6 changes: 4 additions & 2 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,16 @@ def _compute_weight(
),
axis=-1,
) # handle masked nnei.
sigma = numerator / denominator
with np.errstate(divide="ignore", invalid="ignore"):
sigma = numerator / denominator
u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin)
coef = np.zeros_like(u)
left_mask = sigma < self.sw_rmin
mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax)
right_mask = sigma >= self.sw_rmax
coef[left_mask] = 1
smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
with np.errstate(invalid="ignore"):
smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
coef[mid_mask] = smooth[mid_mask]
coef[right_mask] = 0
self.zbl_weight = coef
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DenoiseNet,
)
from .dipole import (
DipoleFittingNetType,
DipoleFittingNet,
)
from .ener import (
EnergyFittingNet,
Expand All @@ -25,7 +25,7 @@
__all__ = [
"FittingNetAttenLcc",
"DenoiseNet",
"DipoleFittingNetType",
"DipoleFittingNet",
"EnergyFittingNet",
"EnergyFittingNetDirect",
"Fitting",
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
log = logging.getLogger(__name__)


class DipoleFittingNetType(Fitting):
class DipoleFittingNet(Fitting):
def __init__(
self, ntypes, embedding_width, neuron, out_dim, resnet_dt=True, **kwargs
):
Expand Down
Loading

0 comments on commit 15f8d25

Please sign in to comment.