
Commit

change macros
bokutotu committed Nov 8, 2024
1 parent 5c87e69 commit 6ccf9a0
Showing 4 changed files with 19 additions and 14 deletions.
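
In short, the three test files under zenu-macros/tests stop naming the zenu_layer, zenu_macros, and zenu_matrix crates directly and instead go through the paths re-exported by the zenu facade crate, importing the Parameters trait inside each test function; zenu/src/lib.rs makes its zenu_macros re-export public to match. Below is a minimal sketch of the resulting pattern, pieced together from the diffs that follow; the struct name, the Linear<T, D> field type, and the T: Num, D: Device bounds are illustrative assumptions, not part of the commit.

use zenu::layer::layers::linear::Linear;
use zenu::macros::Parameters as ParametersDerive;
use zenu::matrix::{
    device::{cpu::Cpu, Device},
    num::Num,
};

// Derive the Parameters impl through the facade's macro re-export.
#[derive(ParametersDerive)]
#[parameters(num = T, device = D)]
pub struct ExampleNet<T, D>
where
    T: Num,
    D: Device,
{
    linear: Linear<T, D>,
}

#[test]
fn facade_paths() {
    // The Parameters trait, which provides `.parameters()`, is imported
    // inside the test instead of at module level.
    use zenu::layer::Parameters;
    let net = ExampleNet::<f32, Cpu> {
        linear: Linear::new(2, 2, true),
    };
    let _params = net.parameters();
}

The other two test files follow the same import shape, with Conv2d and MaxPool2d layers alongside Linear.
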
10 changes: 4 additions & 6 deletions zenu-macros/tests/include_vector_map_model.rs
@@ -1,12 +1,9 @@
use std::collections::HashMap;

use rand_distr::{Distribution, StandardNormal};
use zenu_layer::{
layers::{conv2d::Conv2d, linear::Linear, max_pool_2d::MaxPool2d},
Parameters,
};
use zenu_macros::Parameters;
use zenu_matrix::{
use zenu::layer::layers::{conv2d::Conv2d, linear::Linear, max_pool_2d::MaxPool2d};
use zenu::macros::Parameters;
use zenu::matrix::{
device::{cpu::Cpu, Device},
num::Num,
};
@@ -70,6 +67,7 @@ impl<T: Num, D: Device> ConvNet<T, D> {

#[test]
fn vec_map() {
use zenu::layer::Parameters;
let model = ConvNet::<f32, Cpu>::new();
let conv_fileter_0 = model.conv_blocks[0].conv2d.filter.clone();
let conv_bias_0 = model.conv_blocks[0].conv2d.bias.clone();
11 changes: 5 additions & 6 deletions zenu-macros/tests/multi_parameter_struct.rs
@@ -1,10 +1,7 @@
use rand_distr::{Distribution, StandardNormal};
use zenu_layer::{
layers::{conv2d::Conv2d, linear::Linear, max_pool_2d::MaxPool2d},
Parameters,
};
use zenu_macros::Parameters;
use zenu_matrix::{
use zenu::layer::layers::{conv2d::Conv2d, linear::Linear, max_pool_2d::MaxPool2d};
use zenu::macros::Parameters;
use zenu::matrix::{
device::{cpu::Cpu, Device},
num::Num,
};
@@ -49,6 +46,7 @@ impl<T: Num, D: Device> ConvNet<T, D> {

#[test]
fn multi_params() {
use zenu::layer::Parameters;
let model = ConvNet::<f32, Cpu>::new();
let conv_fileter = model.conv_block.conv2d.filter.clone();
let conv_bias = model.conv_block.conv2d.bias.clone();
@@ -88,6 +86,7 @@ fn multi_params() {

#[test]
fn test_load_parameters_convnet() {
use zenu::layer::Parameters;
let model = ConvNet::<f32, Cpu>::new();
let parameters = model.parameters();

10 changes: 9 additions & 1 deletion zenu-macros/tests/small_case.rs
@@ -1,7 +1,13 @@
use zenu::layer::layers::linear::Linear;
use zenu::macros::Parameters as ParametersDerive;
use zenu::matrix::{
device::{cpu::Cpu, Device},
matrix::Matrix,
num::Num,
};
use zenu_test::assert_val_eq;

#[derive(Parameters)]
#[derive(ParametersDerive)]
#[parameters(num = T, device = D)]
pub struct Hoge<T, D>
where
@@ -13,6 +19,7 @@ where

#[test]
fn small_net() {
use zenu::layer::Parameters;
let hoge = Hoge::<f32, Cpu> {
linear: Linear::new(2, 2, true),
};
@@ -50,6 +57,7 @@ fn small_net() {

#[test]
fn test_load_parameters() {
use zenu::layer::Parameters;
let base_model = Hoge::<f32, Cpu> {
linear: Linear::new(2, 2, true),
};
2 changes: 1 addition & 1 deletion zenu/src/lib.rs
@@ -15,7 +15,7 @@ use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};
use zenu_optimizer::Optimizer;

extern crate zenu_macros;
pub extern crate zenu_macros;

pub use zenu_autograd as autograd;
pub use zenu_layer as layer;
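
The single change in zenu/src/lib.rs is what makes the facade path usable from the tests above: marking the extern crate declaration as pub re-exports the proc-macro crate from zenu's root. A hedged sketch of the relevant facade lines follows; most of lib.rs is collapsed in this view, so the exact spelling of the macros alias is an assumption inferred from the zenu::macros::Parameters imports in the tests.

pub extern crate zenu_macros;       // this commit: the crate becomes reachable as zenu::zenu_macros
pub use zenu_macros as macros;      // assumed alias elsewhere in lib.rs, giving zenu::macros::Parameters
pub use zenu_autograd as autograd;  // shown in the diff above
pub use zenu_layer as layer;        // shown in the diff above
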

0 comments on commit 6ccf9a0
