Skip to content

Commit

Permalink
fix(cc,pt): translate PT exceptions to the DeePMD-kit exception (#3918)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced error handling with the new `translate_error` method,
improving the user experience by providing clearer error messages during
model initialization and computation.
  
- **Tests**
- Added a test to verify exception handling when initializing `DeepPot`
with a non-existent file.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Jun 27, 2024
1 parent 17cdcb0 commit cd60d5f
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 9 deletions.
6 changes: 6 additions & 0 deletions source/api_cc/include/DeepPotPT.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,12 @@ class DeepPotPT : public DeepPotBase {
bool gpu_enabled;
at::Tensor firstneigh_tensor;
torch::Dict<std::string, torch::Tensor> comm_dict;
/**
* @brief Translate PyTorch exceptions to the DeePMD-kit exception.
* @param[in] f The function to run.
* @example translate_error([&](){...});
*/
void translate_error(std::function<void()> f);
};

} // namespace deepmd
47 changes: 38 additions & 9 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,34 @@
#ifdef BUILD_PYTORCH
#include "DeepPotPT.h"

#include <torch/csrc/jit/runtime/jit_exception.h>

#include <cstdint>

#include "common.h"
#include "device.h"
#include "errors.h"

using namespace deepmd;

void DeepPotPT::translate_error(std::function<void()> f) {
try {
f();
// it seems that libtorch may throw different types of exceptions which are
// inherbited from different base classes
// https://github.com/pytorch/pytorch/blob/13316a8d4642454012d34da0d742f1ba93fc0667/torch/csrc/jit/runtime/interpreter.cpp#L924-L939
} catch (const c10::Error& e) {
throw deepmd::deepmd_exception("DeePMD-kit PyTorch backend error: " +
std::string(e.what()));
} catch (const torch::jit::JITException& e) {
throw deepmd::deepmd_exception("DeePMD-kit PyTorch backend JIT error: " +
std::string(e.what()));
} catch (const std::runtime_error& e) {
throw deepmd::deepmd_exception("DeePMD-kit PyTorch backend error: " +
std::string(e.what()));
}
}

torch::Tensor createNlistTensor(const std::vector<std::vector<int>>& data) {
std::vector<torch::Tensor> row_tensors;

Expand All @@ -26,7 +47,7 @@ DeepPotPT::DeepPotPT(const std::string& model,
const std::string& file_content)
: inited(false) {
try {
init(model, gpu_rank, file_content);
translate_error([&] { init(model, gpu_rank, file_content); });
} catch (...) {
// Clean up and rethrow, as the destructor will not be called
throw;
Expand Down Expand Up @@ -444,8 +465,10 @@ void DeepPotPT::computew(std::vector<double>& ener,
const std::vector<double>& box,
const std::vector<double>& fparam,
const std::vector<double>& aparam) {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
fparam, aparam);
translate_error([&] {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
fparam, aparam);
});
}
void DeepPotPT::computew(std::vector<double>& ener,
std::vector<float>& force,
Expand All @@ -457,8 +480,10 @@ void DeepPotPT::computew(std::vector<double>& ener,
const std::vector<float>& box,
const std::vector<float>& fparam,
const std::vector<float>& aparam) {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
fparam, aparam);
translate_error([&] {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
fparam, aparam);
});
}
void DeepPotPT::computew(std::vector<double>& ener,
std::vector<double>& force,
Expand All @@ -473,8 +498,10 @@ void DeepPotPT::computew(std::vector<double>& ener,
const int& ago,
const std::vector<double>& fparam,
const std::vector<double>& aparam) {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
nghost, inlist, ago, fparam, aparam);
translate_error([&] {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
nghost, inlist, ago, fparam, aparam);
});
}
void DeepPotPT::computew(std::vector<double>& ener,
std::vector<float>& force,
Expand All @@ -489,8 +516,10 @@ void DeepPotPT::computew(std::vector<double>& ener,
const int& ago,
const std::vector<float>& fparam,
const std::vector<float>& aparam) {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
nghost, inlist, ago, fparam, aparam);
translate_error([&] {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
nghost, inlist, ago, fparam, aparam);
});
}
void DeepPotPT::computew_mixed_type(std::vector<double>& ener,
std::vector<double>& force,
Expand Down
4 changes: 4 additions & 0 deletions source/api_cc/tests/test_deepmd_exception.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ TEST(TestDeepmdException, deepmdexception_nofile_deeppot) {
ASSERT_THROW(deepmd::DeepPot("_no_such_file.pb"), deepmd::deepmd_exception);
}

TEST(TestDeepmdException, deepmdexception_nofile_deeppot_pt) {
ASSERT_THROW(deepmd::DeepPot("_no_such_file.pth"), deepmd::deepmd_exception);
}

TEST(TestDeepmdException, deepmdexception_nofile_deeppotmodeldevi) {
ASSERT_THROW(
deepmd::DeepPotModelDevi({"_no_such_file.pb", "_no_such_file.pb"}),
Expand Down

0 comments on commit cd60d5f

Please sign in to comment.