Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Feb 23, 2024
1 parent 55bcf50 commit 80c78a9
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def get_class_by_type(cls, model_type: str):
PairwiseDPRc,
)

model_type = input.get("type", "standard")
if model_type == "standard":
return StandardModel
elif model_type == "multi":
Expand All @@ -139,7 +138,7 @@ def get_class_by_type(cls, model_type: str):
def __new__(cls, *args, **kwargs):
if cls is Model:
# init model
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
cls = cls.get_class_by_type(kwargs.get("type", "standard"))
return cls.__new__(cls, *args, **kwargs)
return super().__new__(cls)

Expand Down Expand Up @@ -578,7 +577,7 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict:
dict
The updated local data
"""
cls = cls.get_class_by_type(j_get_type(local_jdata, cls.__name__))
cls = cls.get_class_by_type(local_jdata.get("type", "standard"))
return cls.update_sel(global_jdata, local_jdata)

@classmethod
Expand All @@ -602,7 +601,7 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
"""
if cls is Descriptor:
return Descriptor.get_class_by_type(
j_get_type(data, cls.__name__)
data.get("type", "standard")
).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

Expand Down

0 comments on commit 80c78a9

Please sign in to comment.