Skip to content

Commit

Permalink
se_e3
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed May 30, 2024
1 parent af82658 commit ecb7b66
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
28 changes: 24 additions & 4 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -348,15 +351,32 @@ def deserialize(cls, data: dict) -> "DescrptSeT":
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(

Check warning on line 379 in deepmd/dpmodel/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_t.py#L379

Added line #L379 was not covered by tests
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist

Check warning on line 382 in deepmd/dpmodel/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_t.py#L382

Added line #L382 was not covered by tests
28 changes: 24 additions & 4 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
Expand Down Expand Up @@ -324,18 +327,35 @@ def t_cvt(xx):
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(

Check warning on line 355 in deepmd/pt/model/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_t.py#L355

Added line #L355 was not covered by tests
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist

Check warning on line 358 in deepmd/pt/model/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_t.py#L358

Added line #L358 was not covered by tests


@DescriptorBlock.register("se_e3")
Expand Down

0 comments on commit ecb7b66

Please sign in to comment.