Skip to content

Commit

Permalink
fix(pt): fix ValueError when array byte order is not native (#4100)
Browse files Browse the repository at this point in the history
Fix #4099.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced tensor data type handling for improved numerical stability
and performance in deep learning computations.
- Introduced a precision dictionary to ensure input data is processed
with the correct precision.

- **Bug Fixes**
- Improved clarity and robustness in the handling of data types within
the model evaluation process.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Sep 4, 2024
1 parent 1abb89b commit 46632f9
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch

from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
from deepmd.dpmodel.output_def import (
ModelOutputDef,
OutputVariableCategory,
Expand Down Expand Up @@ -54,6 +55,7 @@
from deepmd.pt.utils.env import (
DEVICE,
GLOBAL_PT_FLOAT_PRECISION,
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.utils import (
to_torch_tensor,
Expand Down Expand Up @@ -380,14 +382,22 @@ def _eval_model(
natoms = len(atom_types[0])

coord_input = torch.tensor(
coords.reshape([nframes, natoms, 3]),
coords.reshape([nframes, natoms, 3]).astype(
NP_PRECISION_DICT[RESERVED_PRECISON_DICT[GLOBAL_PT_FLOAT_PRECISION]]
),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
type_input = torch.tensor(
atom_types.astype(NP_PRECISION_DICT[RESERVED_PRECISON_DICT[torch.long]]),
dtype=torch.long,
device=DEVICE,
)
if cells is not None:
box_input = torch.tensor(
cells.reshape([nframes, 3, 3]),
cells.reshape([nframes, 3, 3]).astype(
NP_PRECISION_DICT[RESERVED_PRECISON_DICT[GLOBAL_PT_FLOAT_PRECISION]]
),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
Expand Down

0 comments on commit 46632f9

Please sign in to comment.