Skip to content

Commit

Permalink
[r2] fix: move resnet_dt checking from graph to tabulate
Browse files Browse the repository at this point in the history
Fix deepmodeling#3950. Backport a part of deepmodeling#3263 to r2 (the whole of deepmodeling#3263 is not likely to be backported).

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Jul 5, 2024
1 parent 84ca63c commit 897376a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 0 additions & 4 deletions deepmd/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,6 @@ def get_embedding_net_nodes_from_graph_def(
embedding_net_nodes = get_pattern_nodes_from_graph_def(
graph_def, embedding_net_pattern
)
for key in embedding_net_nodes.keys():
assert (
key.find("bias") > 0 or key.find("matrix") > 0
), "currently, only support weight matrix and bias matrix at the tabulation op!"
return embedding_net_nodes


Expand Down
4 changes: 4 additions & 0 deletions deepmd/utils/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def __init__(
self.embedding_net_nodes = get_embedding_net_nodes_from_graph_def(
self.graph_def, suffix=self.suffix
)
for key in self.embedding_net_nodes.keys():
assert (
key.find("bias") > 0 or key.find("matrix") > 0
), "currently, only support weight matrix and bias matrix at the tabulation op!"

# move it to the descriptor class
# for tt in self.exclude_types:
Expand Down

0 comments on commit 897376a

Please sign in to comment.