Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 22, 2023
1 parent cc9c470 commit 8bd1953
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
8 changes: 6 additions & 2 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
GraphWithoutTensorError,
)
from deepmd.utils.graph import (
get_extra_embedding_net_variables_from_graph_def,
get_pattern_nodes_from_graph_def,
get_tensor_by_name_from_graph,
get_extra_embedding_net_variables_from_graph_def,
)
from deepmd.utils.network import (
embedding_net,
Expand Down Expand Up @@ -1333,7 +1333,11 @@ def init_variables(
extra_suffix = "_one_side_ebd"

Check warning on line 1333 in deepmd/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se_a.py#L1333

Added line #L1333 was not covered by tests
else:
extra_suffix = "_two_side_ebd"
self.extra_embedding_net_variables = get_extra_embedding_net_variables_from_graph_def(graph_def, suffix, extra_suffix, self.layer_size)
self.extra_embedding_net_variables = (
get_extra_embedding_net_variables_from_graph_def(
graph_def, suffix, extra_suffix, self.layer_size
)
)

@property
def explicit_ntypes(self) -> bool:
Expand Down
8 changes: 6 additions & 2 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
)
from deepmd.utils.graph import (
get_attention_layer_variables_from_graph_def,
get_extra_embedding_net_variables_from_graph_def,
get_pattern_nodes_from_graph_def,
get_tensor_by_name_from_graph,
get_extra_embedding_net_variables_from_graph_def,
)
from deepmd.utils.network import (
embedding_net,
Expand Down Expand Up @@ -1311,7 +1311,11 @@ def init_variables(
]

if self.stripped_type_embedding:
self.two_side_embeeding_net_variables = get_extra_embedding_net_variables_from_graph_def(graph_def, suffix, "_two_side_ebd", self.layer_size)
self.two_side_embeeding_net_variables = (
get_extra_embedding_net_variables_from_graph_def(
graph_def, suffix, "_two_side_ebd", self.layer_size
)
)

def build_type_exclude_mask(
self,
Expand Down
13 changes: 9 additions & 4 deletions deepmd/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,15 @@ def get_embedding_net_variables_from_graph_def(
embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape)
return embedding_net_variables


def get_variables_from_graph_def_as_numpy_array(graph_def: tf.GraphDef, pattern: str):
"""Get variables from the given tf.GraphDef object, with numpy array returns.
Parameters
----------
graph_def
The input tf.GraphDef object
suffix : str
pattern : str
The name of variable
Returns
Expand All @@ -263,7 +264,10 @@ def get_variables_from_graph_def_as_numpy_array(graph_def: tf.GraphDef, pattern:
tensor_value = get_tensor_by_type(node, dtype)

Check warning on line 264 in deepmd/utils/graph.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/graph.py#L264

Added line #L264 was not covered by tests
return np.reshape(tensor_value, tensor_shape)

def get_extra_embedding_net_variables_from_graph_def(graph_def: tf.GraphDef, suffix: str, extra_suffix: str, layer_size: int):

def get_extra_embedding_net_variables_from_graph_def(
graph_def: tf.GraphDef, suffix: str, extra_suffix: str, layer_size: int
):
"""Get extra embedding net variables from the given tf.GraphDef object.
The "extra embedding net" means the embedding net with only type embeddings input,
which occurs in "se_atten_v2" and "se_a_ebd_v2" descriptor.
Expand All @@ -274,10 +278,10 @@ def get_extra_embedding_net_variables_from_graph_def(graph_def: tf.GraphDef, suf
The input tf.GraphDef object
suffix : str
The "common" suffix in the descriptor
extra_suffix: str
extra_suffix : str
This value depends on the value of "type_one_side".
It should always be "_one_side_ebd" or "_two_side_ebd"
layer_size: int
layer_size : int
The layer size of the embedding net
Returns
Expand All @@ -297,6 +301,7 @@ def get_extra_embedding_net_variables_from_graph_def(graph_def: tf.GraphDef, suf
] = get_variables_from_graph_def_as_numpy_array(graph_def, bias_pattern)
return extra_embedding_net_variables


def get_embedding_net_variables(model_file: str, suffix: str = "") -> Dict:
"""Get the embedding net variables with the given frozen model(model_file).
Expand Down

0 comments on commit 8bd1953

Please sign in to comment.