From 995c776f1c52186f975fea8eb62ccd0a29a0e584 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 6 Sep 2024 10:48:44 +0800 Subject: [PATCH] version compat --- deepmd/dpmodel/descriptor/dpa2.py | 8 +++++++- deepmd/pt/model/descriptor/dpa2.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 2e76b5424..2d2312c4a 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -927,7 +927,8 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "DescrptDPA2": data = data.copy() - check_version_compatibility(data.pop("@version"), 3, 3) + version = data.pop("@version") + check_version_compatibility(version, 3, 1) data.pop("@class") data.pop("type") repinit_variable = data.pop("repinit_variable").copy() @@ -941,6 +942,11 @@ def deserialize(cls, data: dict) -> "DescrptDPA2": g1_shape_tranform = data.pop("g1_shape_tranform") tebd_transform = data.pop("tebd_transform", None) add_tebd_to_repinit_out = data["add_tebd_to_repinit_out"] + if version < 3: + # compat with old version + data["repformer_args"]["use_sqrt_nnei"] = False + data["repformer_args"]["g1_out_conv"] = False + data["repformer_args"]["g1_out_mlp"] = False data["repinit"] = RepinitArgs(**data.pop("repinit_args")) data["repformer"] = RepformerArgs(**data.pop("repformer_args")) # compat with version 1 diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 8c3ae36cf..cabbdae17 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -603,7 +603,8 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "DescrptDPA2": data = data.copy() - check_version_compatibility(data.pop("@version"), 3, 3) + version = data.pop("@version") + check_version_compatibility(version, 3, 1) data.pop("@class") data.pop("type") repinit_variable = data.pop("repinit_variable").copy() @@ -617,6 +618,11 @@ def deserialize(cls, data: dict) -> "DescrptDPA2": g1_shape_tranform = data.pop("g1_shape_tranform") tebd_transform = data.pop("tebd_transform", None) add_tebd_to_repinit_out = data["add_tebd_to_repinit_out"] + if version < 3: + # compat with old version + data["repformer_args"]["use_sqrt_nnei"] = False + data["repformer_args"]["g1_out_conv"] = False + data["repformer_args"]["g1_out_mlp"] = False data["repinit"] = RepinitArgs(**data.pop("repinit_args")) data["repformer"] = RepformerArgs(**data.pop("repformer_args")) # compat with version 1