Skip to content

Commit

Permalink
use 0-D Tensor as buffer shape
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Sep 27, 2024
1 parent e3c1ceb commit 49ba5a5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,33 +77,33 @@ def _string_to_array(s: str) -> List[int]:
# register 'has_message_passing' as buffer(cast to int32 as problems may meets with vector<bool>)
self.register_buffer(
"buffer_has_message_passing",
paddle.to_tensor([self.has_message_passing()], dtype="int32"),
paddle.to_tensor(self.has_message_passing(), dtype="int32"),
)
self.buffer_has_message_passing.name = "buffer_has_message_passing"
# register 'ntypes' as buffer
self.register_buffer(
"buffer_ntypes", paddle.to_tensor([self.ntypes], dtype="int32")
"buffer_ntypes", paddle.to_tensor(self.ntypes, dtype="int32")
)
self.buffer_ntypes.name = "buffer_ntypes"
# register 'rcut' as buffer
self.register_buffer(
"buffer_rcut", paddle.to_tensor([self.rcut], dtype="float64")
"buffer_rcut", paddle.to_tensor(self.rcut, dtype="float64")
)
self.buffer_rcut.name = "buffer_rcut"
# register 'dfparam' as buffer
self.register_buffer(
"buffer_dfparam", paddle.to_tensor([self.get_dim_fparam()], dtype="int32")
"buffer_dfparam", paddle.to_tensor(self.get_dim_fparam(), dtype="int32")
)
self.buffer_dfparam.name = "buffer_dfparam"
# register 'daparam' as buffer
self.register_buffer(
"buffer_daparam", paddle.to_tensor([self.get_dim_aparam()], dtype="int32")
"buffer_daparam", paddle.to_tensor(self.get_dim_aparam(), dtype="int32")
)
self.buffer_daparam.name = "buffer_daparam"
# register 'aparam_nall' as buffer
self.register_buffer(
"buffer_aparam_nall",
paddle.to_tensor([self.is_aparam_nall()], dtype="int32"),
paddle.to_tensor(self.is_aparam_nall(), dtype="int32"),
)
self.buffer_aparam_nall.name = "buffer_aparam_nall"

Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepPotPD.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
#include "DeepPotPD.h"

#include <cstdint>
#include <stdexcept>
#include <numeric>

#include "common.h"
#include "device.h"
#include "errors.h"

using namespace deepmd;

Expand Down

0 comments on commit 49ba5a5

Please sign in to comment.