Resilient propagation (Rprop) optimizer for tch-rs, ported from PyTorch's `torch.optim.Rprop`.
⚠️ Currently only tested with simple models!
Licensed under the same terms as PyTorch; see LICENSE.
Add the dependency to your `Cargo.toml`:
```toml
[dependencies]
rprop-tch = { git = "https://github.com/offdroid/rprop-tch-rs.git" }
```
Usage mirrors `tch::nn::Optimizer`:
```rust
let vs = tch::nn::VarStore::new(tch::Device::Cpu);
// Init model with `vs`
let net: &dyn tch::nn::Module = todo!();
// Build Rprop optimizer, here with default parameters
let mut opt = rprop_tch::Rprop::build_default(&vs, Some(0.01));
// Training loop
for epoch in 1..10 {
    let (x, y) = todo!();
    let loss: tch::Tensor = net.forward(&x).mse_loss(&y);
    // Use it like tch::nn::Optimizer
    opt.zero_grad();
    loss.backward();
    opt.step();
}
```
See the `examples` directory, or run:
```sh
cargo run --example basic
```