Skip to content

Commit

Permalink
fix uts
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 22, 2025
1 parent 64030f5 commit 278f51d
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,70 @@ def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return True

@classmethod
def deserialize(cls, data):
"""Deserialize the descriptor block."""
data = data.copy()
edge_embd = NativeLayer.deserialize(data.pop("edge_embd"))
angle_embd = NativeLayer.deserialize(data.pop("angle_embd"))
layers = [RepFlowLayer.deserialize(dd) for dd in data.pop("repflow_layers")]
env_mat_edge = EnvMat.deserialize(data.pop("env_mat_edge"))
env_mat_angle = EnvMat.deserialize(data.pop("env_mat_angle"))
variables = data.pop("@variables")
davg = variables["davg"]
dstd = variables["dstd"]
obj = cls(**data)
obj.edge_embd = edge_embd
obj.angle_embd = angle_embd
obj.layers = layers
obj.env_mat_edge = env_mat_edge
obj.env_mat_angle = env_mat_angle
obj.mean = davg
obj.stddev = dstd
return obj

def serialize(self):
"""Serialize the descriptor block."""
return {
"e_rcut": self.e_rcut,
"e_rcut_smth": self.e_rcut_smth,
"e_sel": self.e_sel,
"a_rcut": self.a_rcut,
"a_rcut_smth": self.a_rcut_smth,
"a_sel": self.a_sel,
"ntypes": self.ntypes,
"nlayers": self.nlayers,
"n_dim": self.n_dim,
"e_dim": self.e_dim,
"a_dim": self.a_dim,
"a_compress_rate": self.a_compress_rate,
"a_compress_e_rate": self.a_compress_e_rate,
"a_compress_use_split": self.a_compress_use_split,
"n_multi_edge_message": self.n_multi_edge_message,
"axis_neuron": self.axis_neuron,
"update_angle": self.update_angle,
"activation_function": self.activation_function,
"update_style": self.update_style,
"update_residual": self.update_residual,
"update_residual_init": self.update_residual_init,
"set_davg_zero": self.set_davg_zero,
"exclude_types": self.exclude_types,
"env_protection": self.env_protection,
"precision": self.precision,
"fix_stat_std": self.fix_stat_std,
"optim_update": self.optim_update,
# variables
"edge_embd": self.edge_embd.serialize(),
"angle_embd": self.angle_embd.serialize(),
"repflow_layers": [layer.serialize() for layer in self.layers],
"env_mat_edge": self.env_mat_edge.serialize(),
"env_mat_angle": self.env_mat_angle.serialize(),
"@variables": {
"davg": to_numpy_array(self["davg"]),
"dstd": to_numpy_array(self["dstd"]),
},
}


class RepFlowLayer(NativeOP):
def __init__(
Expand Down

0 comments on commit 278f51d

Please sign in to comment.