From 8bd195379446aed2fe294f69be6729c4b3c16040 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Nov 2023 14:02:30 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/descriptor/se_a.py | 8 ++++++-- deepmd/descriptor/se_atten.py | 8 ++++++-- deepmd/utils/graph.py | 13 +++++++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 920619593c..1cbbcc2c51 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -41,9 +41,9 @@ GraphWithoutTensorError, ) from deepmd.utils.graph import ( + get_extra_embedding_net_variables_from_graph_def, get_pattern_nodes_from_graph_def, get_tensor_by_name_from_graph, - get_extra_embedding_net_variables_from_graph_def, ) from deepmd.utils.network import ( embedding_net, @@ -1333,7 +1333,11 @@ def init_variables( extra_suffix = "_one_side_ebd" else: extra_suffix = "_two_side_ebd" - self.extra_embedding_net_variables = get_extra_embedding_net_variables_from_graph_def(graph_def, suffix, extra_suffix, self.layer_size) + self.extra_embedding_net_variables = ( + get_extra_embedding_net_variables_from_graph_def( + graph_def, suffix, extra_suffix, self.layer_size + ) + ) @property def explicit_ntypes(self) -> bool: diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index 4cef308ff1..1d1f2c7fa6 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -42,9 +42,9 @@ ) from deepmd.utils.graph import ( get_attention_layer_variables_from_graph_def, + get_extra_embedding_net_variables_from_graph_def, get_pattern_nodes_from_graph_def, get_tensor_by_name_from_graph, - get_extra_embedding_net_variables_from_graph_def, ) from deepmd.utils.network import ( embedding_net, @@ -1311,7 +1311,11 @@ def init_variables( ] if self.stripped_type_embedding: - self.two_side_embeeding_net_variables = get_extra_embedding_net_variables_from_graph_def(graph_def, suffix, "_two_side_ebd", self.layer_size) + self.two_side_embeeding_net_variables = ( + get_extra_embedding_net_variables_from_graph_def( + graph_def, suffix, "_two_side_ebd", self.layer_size + ) + ) def build_type_exclude_mask( self, diff --git a/deepmd/utils/graph.py b/deepmd/utils/graph.py index 20fb7b16c5..86d7d6e2c6 100644 --- a/deepmd/utils/graph.py +++ b/deepmd/utils/graph.py @@ -236,6 +236,7 @@ def get_embedding_net_variables_from_graph_def( embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape) return embedding_net_variables + def get_variables_from_graph_def_as_numpy_array(graph_def: tf.GraphDef, pattern: str): """Get variables from the given tf.GraphDef object, with numpy array returns. @@ -243,7 +244,7 @@ def get_variables_from_graph_def_as_numpy_array(graph_def: tf.GraphDef, pattern: ---------- graph_def The input tf.GraphDef object - suffix : str + pattern : str The name of variable Returns @@ -263,7 +264,10 @@ def get_variables_from_graph_def_as_numpy_array(graph_def: tf.GraphDef, pattern: tensor_value = get_tensor_by_type(node, dtype) return np.reshape(tensor_value, tensor_shape) -def get_extra_embedding_net_variables_from_graph_def(graph_def: tf.GraphDef, suffix: str, extra_suffix: str, layer_size: int): + +def get_extra_embedding_net_variables_from_graph_def( + graph_def: tf.GraphDef, suffix: str, extra_suffix: str, layer_size: int +): """Get extra embedding net variables from the given tf.GraphDef object. The "extra embedding net" means the embedding net with only type embeddings input, which occurs in "se_atten_v2" and "se_a_ebd_v2" descriptor. @@ -274,10 +278,10 @@ def get_extra_embedding_net_variables_from_graph_def(graph_def: tf.GraphDef, suf The input tf.GraphDef object suffix : str The "common" suffix in the descriptor - extra_suffix: str + extra_suffix : str This value depends on the value of "type_one_side". It should always be "_one_side_ebd" or "_two_side_ebd" - layer_size: int + layer_size : int The layer size of the embedding net Returns @@ -297,6 +301,7 @@ def get_extra_embedding_net_variables_from_graph_def(graph_def: tf.GraphDef, suf ] = get_variables_from_graph_def_as_numpy_array(graph_def, bias_pattern) return extra_embedding_net_variables + def get_embedding_net_variables(model_file: str, suffix: str = "") -> Dict: """Get the embedding net variables with the given frozen model(model_file).