From 882d91df4dad6d942bb5312a702dda6ba88731c9 Mon Sep 17 00:00:00 2001 From: bokutotu Date: Fri, 8 Nov 2024 22:54:59 +0900 Subject: [PATCH] change macros --- zenu-optimizer/src/adam.rs | 23 ++++++++++++++++++----- zenu-optimizer/src/sgd.rs | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/zenu-optimizer/src/adam.rs b/zenu-optimizer/src/adam.rs index be2ccb58..729ca2fc 100644 --- a/zenu-optimizer/src/adam.rs +++ b/zenu-optimizer/src/adam.rs @@ -1,6 +1,7 @@ use std::{cell::RefCell, rc::Rc}; use zenu_autograd::{creator::zeros::zeros_like, Variable}; +use zenu_layer::Parameters; use zenu_matrix::{device::Device, num::Num}; use crate::Optimizer; @@ -48,8 +49,8 @@ pub struct Adam { // } // } // } -impl Optimizer for Adam { - fn update(&self, parameters: &[Variable]) { +impl> Optimizer for Adam { + fn update(&self, parameters: &P) { let step = *self.step.borrow(); let step = step + T::one(); *self.step.borrow_mut() = step; @@ -57,10 +58,22 @@ impl Optimizer for Adam { let beta1_t = self.beta1.powf(step); let beta2_t = self.beta2.powf(step); - for ((parameter, m), v) in parameters.iter().zip(&self.m).zip(&self.v) { - let grad = parameter.get_grad().unwrap(); - let grad = grad.get_data(); + let weights = parameters.weights(); + let biases = parameters.biases(); + let mut parameters = Vec::new(); + for (_, weight) in weights.iter() { + if let Some(grad) = weight.get_grad() { + parameters.push(grad); + } + } + for (_, bias) in biases.iter() { + if let Some(grad) = bias.get_grad() { + parameters.push(grad); + } + } + for ((parameter, m), v) in parameters.iter().zip(&self.m).zip(&self.v) { + let grad = parameter.get_data(); let mut v = v.get_data_mut(); let mut v = v.to_ref_mut(); let mut m = m.get_data_mut(); diff --git a/zenu-optimizer/src/sgd.rs b/zenu-optimizer/src/sgd.rs index 96aa2fff..f1890f2a 100644 --- a/zenu-optimizer/src/sgd.rs +++ b/zenu-optimizer/src/sgd.rs @@ -18,7 +18,7 @@ impl SGD { } } -impl Optimizer for SGD { +impl> Optimizer for SGD { fn update(&self, parameters: &P) { let weights = parameters.weights(); let biases = parameters.biases();