feat: support seed for pt/dp models (deepmodeling#3773)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Added neural network weight initialization methods: zeros, ones, constants, normal, truncated normal, Kaiming, and Xavier.

- **Improvements**
  - Included an optional `seed` parameter in initialization methods and classes.
  - Implemented a `get_generator` function for random seed management (see the sketch after this summary).

- **Bug Fixes**
  - Ensured the random seed is properly set during training, avoiding potential unintended behavior.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
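The `get_generator` helper itself is not visible in the hunks below. As a hedged illustration of the idea, a minimal version for the PyTorch side might look like this (the name matches the summary above; the body and the usage are assumptions, not the commit's actual code):

```python
from typing import Optional

import torch


def get_generator(seed: Optional[int] = None) -> Optional[torch.Generator]:
    """Return a torch.Generator seeded with `seed`, or None when unseeded.

    Sketch only: passing an explicit generator to in-place samplers such as
    Tensor.normal_ makes parameter initialization reproducible.
    """
    if seed is None:
        # No seed given: fall back to PyTorch's global RNG (old behavior).
        return None
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator


# Usage sketch: two tensors drawn with the same seed are identical.
w1 = torch.empty(3, 4).normal_(generator=get_generator(42))
w2 = torch.empty(3, 4).normal_(generator=get_generator(42))
assert torch.equal(w1, w2)
```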

---------

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
iProzd authored and mtaillefumier committed Sep 18, 2024
1 parent d77521a commit d393280
Showing 17 changed files with 658 additions and 48 deletions.
9 changes: 8 additions & 1 deletion deepmd/dpmodel/utils/network.py
@@ -78,12 +78,13 @@ def __init__(
         activation_function: Optional[str] = None,
         resnet: bool = False,
         precision: str = DEFAULT_PRECISION,
+        seed: Optional[int] = None,
     ) -> None:
         prec = PRECISION_DICT[precision.lower()]
         self.precision = precision
         # only use_timestep when skip connection is established.
         use_timestep = use_timestep and (num_out == num_in or num_out == num_in * 2)
-        rng = np.random.default_rng()
+        rng = np.random.default_rng(seed)
         self.w = rng.normal(size=(num_in, num_out)).astype(prec)
         self.b = rng.normal(size=(num_out,)).astype(prec) if bias else None
         self.idt = rng.normal(size=(num_out,)).astype(prec) if use_timestep else None
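The change above is the core of the dpmodel side: `np.random.default_rng(seed)` returns a deterministic generator when `seed` is not `None`, and a nondeterministic one otherwise. A standalone check of that behavior (shapes are arbitrary):

```python
import numpy as np

# Same seed -> identical weight draws, so two layers built with the same
# seed start from identical parameters.
w1 = np.random.default_rng(42).normal(size=(3, 4))
w2 = np.random.default_rng(42).normal(size=(3, 4))
assert np.array_equal(w1, w2)

# seed=None (the default) keeps the previous nondeterministic behavior.
w3 = np.random.default_rng().normal(size=(3, 4))
assert not np.array_equal(w1, w3)
```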
@@ -313,6 +314,7 @@ def __init__(
         uni_init: bool = True,
         trainable: bool = True,
         precision: str = DEFAULT_PRECISION,
+        seed: Optional[int] = None,
     ) -> None:
         self.eps = eps
         self.uni_init = uni_init
@@ -325,6 +327,7 @@ def __init__(
             activation_function=None,
             resnet=False,
             precision=precision,
+            seed=seed,
         )
         self.w = self.w.squeeze(0)  # keep the weight shape to be [num_in]
         if self.uni_init:
@@ -569,6 +572,7 @@ def __init__(
         activation_function: str = "tanh",
         resnet_dt: bool = False,
         precision: str = DEFAULT_PRECISION,
+        seed: Optional[int] = None,
     ):
         layers = []
         i_in = in_dim
@@ -583,6 +587,7 @@ def __init__(
                     activation_function=activation_function,
                     resnet=True,
                     precision=precision,
+                    seed=seed,
                 ).serialize()
             )
             i_in = i_ot
@@ -669,6 +674,7 @@ def __init__(
         resnet_dt: bool = False,
         precision: str = DEFAULT_PRECISION,
         bias_out: bool = True,
+        seed: Optional[int] = None,
     ):
         super().__init__(
             in_dim,
@@ -688,6 +694,7 @@ def __init__(
                 activation_function=None,
                 resnet=False,
                 precision=precision,
+                seed=seed,
             )
         )
         self.out_dim = out_dim
6 changes: 5 additions & 1 deletion deepmd/pt/model/descriptor/dpa1.py
@@ -172,6 +172,8 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module):
             Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
             Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'.
             The default value is `None`, which means the `tebd_input_mode` setting will be used instead.
+    seed: int, Optional
+        Random seed for parameter initialization.
     use_econf_tebd: bool, Optional
         Whether to use electronic configuration type embedding.
     type_map: List[str], Optional
@@ -225,12 +227,12 @@ def __init__(
         smooth_type_embedding: bool = True,
         type_one_side: bool = False,
         stripped_type_embedding: Optional[bool] = None,
+        seed: Optional[int] = None,
         use_econf_tebd: bool = False,
         type_map: Optional[List[str]] = None,
         # not implemented
         spin=None,
         type: Optional[str] = None,
-        seed: Optional[int] = None,
         old_impl: bool = False,
     ):
         super().__init__()
@@ -275,6 +277,7 @@ def __init__(
             env_protection=env_protection,
             trainable_ln=trainable_ln,
             ln_eps=ln_eps,
+            seed=seed,
             old_impl=old_impl,
         )
         self.use_econf_tebd = use_econf_tebd
@@ -283,6 +286,7 @@ def __init__(
             ntypes,
             tebd_dim,
             precision=precision,
+            seed=seed,
             use_econf_tebd=use_econf_tebd,
             type_map=type_map,
         )
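With `seed` now exposed on the descriptor, two identically configured instances built from the same seed should start from identical parameters. A hypothetical usage sketch (the constructor arguments shown are placeholders inferred from the class, not values from this commit):

```python
import torch

from deepmd.pt.model.descriptor.dpa1 import DescrptDPA1

# Placeholder hyperparameters for illustration only.
kwargs = dict(rcut=6.0, rcut_smth=0.5, sel=[20], ntypes=2, seed=42)
d1 = DescrptDPA1(**kwargs)
d2 = DescrptDPA1(**kwargs)

# With the same seed, every parameter tensor should match element-wise.
for p1, p2 in zip(d1.parameters(), d2.parameters()):
    assert torch.equal(p1, p2)
```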
7 changes: 6 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
@@ -103,7 +103,7 @@ def __init__(
         trainable : bool, optional
             If the parameters are trainable.
         seed : int, optional
-            (Unused yet) Random seed for parameter initialization.
+            Random seed for parameter initialization.
         add_tebd_to_repinit_out : bool, optional
             Whether to add type embedding to the output representation from repinit before inputting it into repformer.
         use_econf_tebd : bool, Optional
@@ -160,6 +160,7 @@ def init_subclass_params(sub_data, sub_class):
             resnet_dt=self.repinit_args.resnet_dt,
             smooth=smooth,
             type_one_side=self.repinit_args.type_one_side,
+            seed=seed,
         )
         self.repformers = DescrptBlockRepformers(
             self.repformer_args.rcut,
@@ -194,6 +195,7 @@ def init_subclass_params(sub_data, sub_class):
             precision=precision,
             trainable_ln=self.repformer_args.trainable_ln,
             ln_eps=self.repformer_args.ln_eps,
+            seed=seed,
             old_impl=old_impl,
         )
         self.use_econf_tebd = use_econf_tebd
@@ -202,6 +204,7 @@ def init_subclass_params(sub_data, sub_class):
             ntypes,
             self.repinit_args.tebd_dim,
             precision=precision,
+            seed=seed,
             use_econf_tebd=self.use_econf_tebd,
             type_map=type_map,
         )
@@ -222,6 +225,7 @@ def init_subclass_params(sub_data, sub_class):
             bias=False,
             precision=precision,
             init="glorot",
+            seed=seed,
         )
         self.tebd_transform = None
         if self.add_tebd_to_repinit_out:
@@ -230,6 +234,7 @@ def init_subclass_params(sub_data, sub_class):
             self.repformers.dim_in,
             bias=False,
             precision=precision,
+            seed=seed,
         )
         assert self.repinit.rcut > self.repformers.rcut
         assert self.repinit.sel[0] > self.repformers.sel[0]
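Note that these hunks forward one and the same `seed` to several sub-modules (repinit, repformers, the type embedding, and the transforms). If sibling modules of the same shape should not start from identical weights, a common pattern is to derive deterministic per-child seeds from the parent seed; a hypothetical sketch of that pattern (not code from this commit):

```python
from typing import Optional


def child_seed(seed: Optional[int], index: int) -> Optional[int]:
    """Derive a distinct, reproducible seed for the index-th sub-module.

    Sketch of a general pattern: mixing the parent seed with the child's
    index keeps runs reproducible while giving each sub-module its own
    random stream.
    """
    if seed is None:
        return None
    # Any stable mixing function works; a multiplicative hash keeps
    # neighboring indices well separated.
    return (seed * 1_000_003 + index) % 2**32


# e.g. (hypothetical) DescrptBlockSeAtten(..., seed=child_seed(seed, 0))
```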
(Diffs for the remaining 14 changed files are not shown.)
