From e4ae74bf949fc35562ffb22809622eebc43fac37 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 19 Jun 2024 10:46:17 -0400 Subject: [PATCH] [r2] fix seeds in se_a and se_atten (#3880) Fix #3799. - **New Features** - Introduced flexibility in specifying seed values, allowing either an integer or a list of integers. - Enhanced seed parameter usage across various initialization methods and classes for more controlled randomization. - **Improvements** - Updated seed initialization logic to include additional computations and dynamic adjustments. - Enhanced documentation for parameters in multiple classes, providing clearer usage guidelines. --------- Signed-off-by: Jinzhe Zeng (cherry picked from commit 0c472d1596ae24b5b548ae6ae38688dbab911de5) Signed-off-by: Jinzhe Zeng --- deepmd/descriptor/se_a.py | 7 +++- deepmd/descriptor/se_atten.py | 16 +++++++- source/tests/test_model_se_a_ebd_v2.py | 56 +++++++++++++------------- source/tests/test_pairwise_dprc.py | 4 +- 4 files changed, 50 insertions(+), 33 deletions(-) diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index f3e62d3672..219f6d2145 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -1031,6 +1031,8 @@ def _filter_lower( mixed_prec=self.mixed_prec, ) net_output = tf.nn.embedding_lookup(net_output, idx) + if (not self.uniform_seed) and (self.seed is not None): + self.seed += self.seed_shift net_output = tf.reshape(net_output, [-1, self.filter_neuron[-1]]) else: xyz_scatter = self._concat_type_embedding( @@ -1042,7 +1044,7 @@ def _filter_lower( ) # natom x 4 x outputs_size if nvnmd_cfg.enable: - return filter_lower_R42GR( + oo = filter_lower_R42GR( type_i, type_input, inputs_i, @@ -1060,6 +1062,9 @@ def _filter_lower( self.filter_resnet_dt, self.embedding_net_variables, ) + if (not self.uniform_seed) and (self.seed is not None): + self.seed += self.seed_shift + return oo if self.compress and (not is_exclude): if self.stripped_type_embedding: net_output = tf.nn.embedding_lookup( diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index 5615863254..5acd525e55 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -959,6 +959,8 @@ def _attention_layers( uniform_seed=self.uniform_seed, initial_variables=self.attention_layer_variables, ) + if not self.uniform_seed and self.seed is not None: + self.seed += 1 K_c = one_layer( input_xyz, self.att_n, @@ -972,6 +974,8 @@ def _attention_layers( uniform_seed=self.uniform_seed, initial_variables=self.attention_layer_variables, ) + if not self.uniform_seed and self.seed is not None: + self.seed += 1 V_c = one_layer( input_xyz, self.att_n, @@ -985,6 +989,8 @@ def _attention_layers( uniform_seed=self.uniform_seed, initial_variables=self.attention_layer_variables, ) + if not self.uniform_seed and self.seed is not None: + self.seed += 1 # # natom x nei_type_i x out_size # xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1] // 4, outputs_size[-1])) # natom x nei_type_i x att_n @@ -1017,6 +1023,8 @@ def _attention_layers( uniform_seed=self.uniform_seed, initial_variables=self.attention_layer_variables, ) + if not self.uniform_seed and self.seed is not None: + self.seed += 1 input_xyz = tf.keras.layers.LayerNormalization( beta_initializer=tf.constant_initializer(self.beta[i]), gamma_initializer=tf.constant_initializer(self.gamma[i]), @@ -1080,6 +1088,8 @@ def _filter_lower( initial_variables=self.embedding_net_variables, mixed_prec=self.mixed_prec, ) + if (not self.uniform_seed) and (self.seed is not None): + self.seed += self.seed_shift else: if self.attn_layer == 0: log.info( @@ -1119,6 +1129,8 @@ def _filter_lower( initial_variables=self.embedding_net_variables, mixed_prec=self.mixed_prec, ) + if (not self.uniform_seed) and (self.seed is not None): + self.seed += self.seed_shift else: net = "filter_net" info = [ @@ -1176,6 +1188,8 @@ def _filter_lower( initial_variables=self.two_side_embeeding_net_variables, mixed_prec=self.mixed_prec, ) + if (not self.uniform_seed) and (self.seed is not None): + self.seed += self.seed_shift two_embd = tf.nn.embedding_lookup( embedding_of_two_side_type_embedding, index_of_two_side ) @@ -1194,8 +1208,6 @@ def _filter_lower( is_sorted=len(self.exclude_types) == 0, ) - if (not self.uniform_seed) and (self.seed is not None): - self.seed += self.seed_shift input_r = tf.slice( tf.reshape(inputs_i, (-1, shape_i[1] // 4, 4)), [0, 0, 1], [-1, -1, 3] ) diff --git a/source/tests/test_model_se_a_ebd_v2.py b/source/tests/test_model_se_a_ebd_v2.py index 71860890ce..0cd74b5f76 100644 --- a/source/tests/test_model_se_a_ebd_v2.py +++ b/source/tests/test_model_se_a_ebd_v2.py @@ -139,37 +139,37 @@ def test_model(self): f = f.reshape([-1]) v = v.reshape([-1]) - refe = [5.435394596262052014e-01] + refe = [6.100037044296185e-01] reff = [ - 6.583728125594628944e-02, - 7.228993116083935744e-02, - 1.971543579114074483e-03, - 6.567474563776359853e-02, - 7.809421727465599983e-02, - -4.866958849094786890e-03, - -8.670511901715304004e-02, - 3.525374157021862048e-02, - 1.415748959800727487e-03, - 6.375813001810648473e-02, - -1.139053242798149790e-01, - -4.178593754384440744e-03, - -1.471737787218250215e-01, - 4.189712704724830872e-02, - 7.011731363309440038e-03, - 3.860874082716164030e-02, - -1.136296927731473005e-01, - -1.353471298745012206e-03, + 8.448651008616304e-02, + 8.613568658155157e-02, + 4.377711655236228e-03, + 9.264613309788312e-02, + 9.351200240060925e-02, + -6.743918515275118e-03, + -1.268078358219972e-01, + 4.855965861982662e-02, + 1.361334787979757e-04, + 4.193213089916692e-02, + -1.324120032345251e-01, + -4.507320444374342e-03, + -1.314595297986654e-01, + 4.120567370248839e-02, + 7.896917575801866e-03, + 3.920259153744955e-02, + -1.370010180699507e-01, + -1.159523750186610e-03, ] refv = [ - -4.243979601186427253e-01, - 1.097173849143971286e-01, - 1.227299373463585502e-02, - 1.097173849143970314e-01, - -2.462891443164323124e-01, - -5.711664180530139426e-03, - 1.227299373463585502e-02, - -5.711664180530143763e-03, - -6.217348853341628408e-04, + -0.277134219204478, + 0.088897922530779, + 0.008633318264458, + 0.088897922530779, + -0.292191560546969, + -0.005709595520904, + 0.008633318264458, + -0.005709595520904, + -0.000682136341924, ] refe = np.reshape(refe, [-1]) reff = np.reshape(reff, [-1]) diff --git a/source/tests/test_pairwise_dprc.py b/source/tests/test_pairwise_dprc.py index 8544193407..6b5d01f1b5 100644 --- a/source/tests/test_pairwise_dprc.py +++ b/source/tests/test_pairwise_dprc.py @@ -519,8 +519,8 @@ def test_model_ener(self): # the model is pairwise! self.assertAllClose(e[1] + e[2] + e[3] - 3 * e[0], e[4] - e[0]) self.assertAllClose(f[1] + f[2] + f[3] - 3 * f[0], f[4] - f[0]) - self.assertAllClose(e[0], 0.189075, 1e-6) - self.assertAllClose(f[0, 0], 0.060047, 1e-6) + self.assertAllClose(e[0], 4.82969, 1e-6) + self.assertAllClose(f[0, 0], -0.104339, 1e-6) def test_nloc(self): jfile = tests_path / "pairwise_dprc.json"