From 46632f90d5915c4489bde4a15e9ca57b3e18cb0a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 3 Sep 2024 20:33:34 -0400 Subject: [PATCH] fix(pt): fix ValueError when array byte order is not native (#4100) Fix #4099. ## Summary by CodeRabbit - **New Features** - Enhanced tensor data type handling for improved numerical stability and performance in deep learning computations. - Introduced a precision dictionary to ensure input data is processed with the correct precision. - **Bug Fixes** - Improved clarity and robustness in the handling of data types within the model evaluation process. Signed-off-by: Jinzhe Zeng --- deepmd/pt/infer/deep_eval.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 9309af657d..af44eba1df 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -14,6 +14,7 @@ import numpy as np import torch +from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT from deepmd.dpmodel.output_def import ( ModelOutputDef, OutputVariableCategory, @@ -54,6 +55,7 @@ from deepmd.pt.utils.env import ( DEVICE, GLOBAL_PT_FLOAT_PRECISION, + RESERVED_PRECISON_DICT, ) from deepmd.pt.utils.utils import ( to_torch_tensor, @@ -380,14 +382,22 @@ def _eval_model( natoms = len(atom_types[0]) coord_input = torch.tensor( - coords.reshape([nframes, natoms, 3]), + coords.reshape([nframes, natoms, 3]).astype( + NP_PRECISION_DICT[RESERVED_PRECISON_DICT[GLOBAL_PT_FLOAT_PRECISION]] + ), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE, ) - type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) + type_input = torch.tensor( + atom_types.astype(NP_PRECISION_DICT[RESERVED_PRECISON_DICT[torch.long]]), + dtype=torch.long, + device=DEVICE, + ) if cells is not None: box_input = torch.tensor( - cells.reshape([nframes, 3, 3]), + cells.reshape([nframes, 3, 3]).astype( + NP_PRECISION_DICT[RESERVED_PRECISON_DICT[GLOBAL_PT_FLOAT_PRECISION]] + ), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE, )