Skip to content

Commit

Permalink
change macros
Browse files — browse the repository at this point in the history
  • Loading branch information
bokutotu committed Nov 8, 2024
1 parent c25714a commit 882d91d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
23 changes: 18 additions & 5 deletions zenu-optimizer/src/adam.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{cell::RefCell, rc::Rc};

use zenu_autograd::{creator::zeros::zeros_like, Variable};
use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};

use crate::Optimizer;
Expand Down Expand Up @@ -48,19 +49,31 @@ pub struct Adam<T: Num, D: Device> {
// }
// }
// }
impl<T: Num, D: Device> Optimizer<T, D> for Adam<T, D> {
fn update(&self, parameters: &[Variable<T, D>]) {
impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for Adam<T, D> {
fn update(&self, parameters: &P) {
let step = *self.step.borrow();
let step = step + T::one();
*self.step.borrow_mut() = step;

let beta1_t = self.beta1.powf(step);
let beta2_t = self.beta2.powf(step);

for ((parameter, m), v) in parameters.iter().zip(&self.m).zip(&self.v) {
let grad = parameter.get_grad().unwrap();
let grad = grad.get_data();
let weights = parameters.weights();
let biases = parameters.biases();
let mut parameters = Vec::new();
for (_, weight) in weights.iter() {
if let Some(grad) = weight.get_grad() {
parameters.push(grad);
}
}
for (_, bias) in biases.iter() {
if let Some(grad) = bias.get_grad() {
parameters.push(grad);
}
}

for ((parameter, m), v) in parameters.iter().zip(&self.m).zip(&self.v) {
let grad = parameter.get_data();
let mut v = v.get_data_mut();
let mut v = v.to_ref_mut();
let mut m = m.get_data_mut();
Expand Down
2 changes: 1 addition & 1 deletion zenu-optimizer/src/sgd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl<T: Num, D: Device> SGD<T, D> {
}
}

impl<T: Num, D: Device, P: Parameters> Optimizer<T, D, P> for SGD<T, D> {
impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for SGD<T, D> {
fn update(&self, parameters: &P) {
let weights = parameters.weights();
let biases = parameters.biases();
Expand Down

0 comments on commit 882d91d

Please sign in to comment.