Skip to content

Commit

Permalink
Merge pull request #829 from necrashter/save-load-docs
Browse files Browse the repository at this point in the history
Document file formats in VarStore::save and load methods
  • Loading branch information
LaurentMazare authored Jan 6, 2024
2 parents a0b8580 + 2dcdac6 commit 9ffb2ca
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/nn/var_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ impl VarStore {
///
/// Weight values for all the tensors currently stored in the
/// var-store are saved in the given file.
///
/// If the given path ends with the suffix `.safetensors`, the file will
/// be saved in safetensors format. Otherwise, libtorch C++ module format
/// will be used. Note that saving in pickle format (`.pt` extension) is
/// not supported by the C++ API of Torch.
pub fn save<T: AsRef<std::path::Path>>(&self, path: T) -> Result<(), TchError> {
let variables = self.variables_.lock().unwrap();
let named_tensors = variables.named_variables.iter().collect::<Vec<_>>();
Expand Down Expand Up @@ -216,6 +221,11 @@ impl VarStore {
/// var-store are loaded from the given file. Note that the set of
/// variables stored in the var-store is not changed, only the values
/// for these tensors are modified.
///
/// The format of the file is deduced from the file extension:
/// - `.safetensors`: The file is assumed to be in safetensors format.
/// - `.bin` or `.pt`: The file is assumed to be in pickle format.
/// - Otherwise, the file is assumed to be in libtorch C++ module format.
pub fn load<T: AsRef<std::path::Path>>(&mut self, path: T) -> Result<(), TchError> {
if self.device != Device::Mps {
self.load_internal(path)
Expand Down

0 comments on commit 9ffb2ca

Please sign in to comment.