fix float precision problem of se_atten in line 217 (#3961) (#3978)
Fix the float precision problem of se_atten at line 217.
Fix the bug of inconsistent energies between the QNN and LAMMPS.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Improved energy calculation methods for more accurate results in the `wrap` module.
  - Introduced new parameters for enhanced configurability in energy-related computations.

- **Improvements**
  - Enhanced handling and processing of energy shift arrays for better performance and accuracy.
  - Updated array manipulation and calculation methods for various wrapping functionalities.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: LiuGroupHNU <liujie123@HNU>
Co-authored-by: MoPinghui <mopinghui1020@gmail.com>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Pinghui Mo <pinghui_mo@outlook.com>
6 people authored Jul 18, 2024
1 parent 24d151a commit 6199b03
Showing 2 changed files with 13 additions and 7 deletions.

deepmd/tf/nvnmd/descriptor/se_atten.py (1 addition, 1 deletion)
@@ -178,7 +178,7 @@ def filter_lower_R42GR(inputs_i, atype, nei_type_vec):
     NIDP = nvnmd_cfg.dscp["NIDP"]
     two_embd_value = nvnmd_cfg.map["gt"]
     # print(two_embd_value)
-
+    two_embd_value = GLOBAL_NP_FLOAT_PRECISION(two_embd_value)
     # copy
     inputs_reshape = op_module.flt_nvnmd(inputs_reshape)
     inputs_reshape = tf.ensure_shape(inputs_reshape, [None, 4])
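Why the cast matters: a minimal, self-contained sketch follows. Nothing in it is DeePMD-kit code; the table contents, sizes, and the local GLOBAL_NP_FLOAT_PRECISION alias are assumptions for illustration. The point is that if one evaluation path rounds the two-body embedding table ("gt") to float32 while another keeps it at the global double precision, the accumulated energies disagree in the trailing digits, which is the kind of QNN-versus-LAMMPS mismatch the commit message describes; casting the table to one global precision removes the inconsistency.

```python
import numpy as np

# Assumed stand-in for DeePMD-kit's global NumPy precision (double by default).
GLOBAL_NP_FLOAT_PRECISION = np.float64

# Hypothetical two-body embedding table ("gt") and a lower-precision copy of it.
rng = np.random.default_rng(0)
gt = rng.normal(size=4096)                      # float64 table
gt_f32 = gt.astype(np.float32)                  # what a float32 path would see

e_a = np.sum(gt)                                 # path A: consistent double precision
e_b = np.sum(GLOBAL_NP_FLOAT_PRECISION(gt_f32))  # path B: table rounded to float32 first
print(e_a - e_b)  # non-zero: the two paths report slightly different "energies"
```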

deepmd/tf/nvnmd/entrypoints/wrap.py (12 additions, 6 deletions; file mode changed from 100644 to 100755)
@@ -163,6 +163,7 @@ def wrap_head(self, nhs, nws):
         ntype number of atomic species
         nnei number of neighbors
         atom_ener atom bias energy
+        ener_fact factor for atom_ener
         """
         nbit = nvnmd_cfg.nbit
         ctrl = nvnmd_cfg.ctrl
@@ -209,19 +210,26 @@ def wrap_head(self, nhs, nws):
             atom_ener = weight["t_bias_atom_e"]
         else:
             atom_ener = [0] * 32
+        atom_ener_shift = []
         nlayer_fit = fitn["nlayer_fit"]
         if VERSION == 0:
             for tt in range(ntype):
                 w, b, _idt = get_fitnet_weight(weight, tt, nlayer_fit - 1, nlayer_fit)
                 shift = atom_ener[tt] + b[0]
-                SHIFT = e.qr(shift, NBIT_FIXD_FL)
-                bs = e.dec2bin(SHIFT, NBIT_MODEL_HEAD, signed=True)[0] + bs
+                atom_ener_shift.append(shift)
         if VERSION == 1:
             for tt in range(ntype):
                 w, b, _idt = get_fitnet_weight(weight, 0, nlayer_fit - 1, nlayer_fit)
                 shift = atom_ener[tt] + b[0]
-                SHIFT = e.qr(shift, NBIT_FIXD_FL)
-                bs = e.dec2bin(SHIFT, NBIT_MODEL_HEAD, signed=True)[0] + bs
+                atom_ener_shift.append(shift)
+        atom_ener_shift = np.array(atom_ener_shift)
+        max_ea = np.ceil(np.log2(np.max(np.abs(atom_ener_shift))))
+        max_ea = np.max([max_ea + NBIT_FIXD_FL - NBIT_MODEL_HEAD + 1, 0])
+        atom_ener_shift = atom_ener_shift / 2**max_ea
+        for shift in atom_ener_shift:
+            SHIFT = e.qr(shift, NBIT_FIXD_FL)
+            bs = e.dec2bin(SHIFT, NBIT_MODEL_HEAD, signed=True)[0] + bs
+        bs = e.dec2bin(max_ea, NBIT_MODEL_HEAD, signed=True)[0] + bs
         # extend
         hs = e.bin2hex(bs)
         hs = e.extend_hex(hs, NBIT_MODEL_HEAD * nhead)
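The block above rescales the per-type energy shifts before fixed-point encoding: a shift quantized with NBIT_FIXD_FL fractional bits needs roughly log2|shift| + NBIT_FIXD_FL magnitude bits, so dividing every shift by 2**max_ea, with max_ea = max(ceil(log2(max|shift|)) + NBIT_FIXD_FL - NBIT_MODEL_HEAD + 1, 0), keeps the largest one inside the signed NBIT_MODEL_HEAD-bit field, and max_ea itself is written to the header so the scale can be undone downstream. A standalone sketch of that arithmetic is below; the bit widths, the example shifts, and quantize_round (a stand-in for e.qr) are illustrative assumptions, not values from the code.

```python
import numpy as np

# Assumed bit widths, for illustration only (the real values come from nvnmd_cfg.nbit).
NBIT_MODEL_HEAD = 32  # total width of the signed header field
NBIT_FIXD_FL = 23     # fractional bits used by the fixed-point quantizer

def quantize_round(x, nbit_fl):
    """Stand-in for e.qr: round onto a grid with nbit_fl fractional bits."""
    return np.round(np.asarray(x, dtype=np.float64) * 2.0**nbit_fl)

# Hypothetical per-type energy shifts (atom_ener[tt] + b[0]).
atom_ener_shift = np.array([-931.7, 12.4, -15210.3])

# Smallest power-of-two scale that keeps every quantized shift inside the signed field.
max_ea = np.ceil(np.log2(np.max(np.abs(atom_ener_shift))))
max_ea = np.max([max_ea + NBIT_FIXD_FL - NBIT_MODEL_HEAD + 1, 0])

scaled = atom_ener_shift / 2**max_ea
codes = quantize_round(scaled, NBIT_FIXD_FL)

assert np.all(np.abs(codes) < 2 ** (NBIT_MODEL_HEAD - 1))  # fits the signed 32-bit field
print(int(max_ea), codes)  # max_ea is also encoded so the scale can be undone later
```

With these example shifts the scale comes out to 2**6 and the largest code is about 1.99e9, just under the 2**31 limit; unscaled, the same shift would need about 37 magnitude bits and would overflow the field, which is presumably what the removed per-loop quantization could run into for large atom energy biases.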
@@ -592,8 +600,6 @@ def wrap_lut(self):
             _d = d[ii]
             _d = np.reshape(_d, [1, -1])
             _d = np.matmul(_d, w)
-            # _d = np.reshape(_d, [-1, 2])
-            # _d = np.concatenate([_d[:,0], _d[:,1]], axis=0)
             d2[ii] = _d
         d2 = e.qr(d2, NBIT_DATA_FL)
         bavc = e.dec2bin(d2, NBIT_WXDB, True)
