From 6ccf9a0147f6eb7ceb10a83c5a8b26f89689f725 Mon Sep 17 00:00:00 2001 From: bokutotu Date: Fri, 8 Nov 2024 21:32:29 +0900 Subject: [PATCH] change macros --- zenu-macros/tests/include_vector_map_model.rs | 10 ++++------ zenu-macros/tests/multi_parameter_struct.rs | 11 +++++------ zenu-macros/tests/small_case.rs | 10 +++++++++- zenu/src/lib.rs | 2 +- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/zenu-macros/tests/include_vector_map_model.rs b/zenu-macros/tests/include_vector_map_model.rs index d805ecf6..ae5d4199 100644 --- a/zenu-macros/tests/include_vector_map_model.rs +++ b/zenu-macros/tests/include_vector_map_model.rs @@ -1,12 +1,9 @@ use std::collections::HashMap; use rand_distr::{Distribution, StandardNormal}; -use zenu_layer::{ - layers::{conv2d::Conv2d, linear::Linear, max_pool_2d::MaxPool2d}, - Parameters, -}; -use zenu_macros::Parameters; -use zenu_matrix::{ +use zenu::layer::layers::{conv2d::Conv2d, linear::Linear, max_pool_2d::MaxPool2d}; +use zenu::macros::Parameters; +use zenu::matrix::{ device::{cpu::Cpu, Device}, num::Num, }; @@ -70,6 +67,7 @@ impl ConvNet { #[test] fn vec_map() { + use zenu::layer::Parameters; let model = ConvNet::::new(); let conv_fileter_0 = model.conv_blocks[0].conv2d.filter.clone(); let conv_bias_0 = model.conv_blocks[0].conv2d.bias.clone(); diff --git a/zenu-macros/tests/multi_parameter_struct.rs b/zenu-macros/tests/multi_parameter_struct.rs index 9e51eb45..7ff23886 100644 --- a/zenu-macros/tests/multi_parameter_struct.rs +++ b/zenu-macros/tests/multi_parameter_struct.rs @@ -1,10 +1,7 @@ use rand_distr::{Distribution, StandardNormal}; -use zenu_layer::{ - layers::{conv2d::Conv2d, linear::Linear, max_pool_2d::MaxPool2d}, - Parameters, -}; -use zenu_macros::Parameters; -use zenu_matrix::{ +use zenu::layer::layers::{conv2d::Conv2d, linear::Linear, max_pool_2d::MaxPool2d}; +use zenu::macros::Parameters; +use zenu::matrix::{ device::{cpu::Cpu, Device}, num::Num, }; @@ -49,6 +46,7 @@ impl ConvNet { #[test] fn multi_params() { + use zenu::layer::Parameters; let model = ConvNet::::new(); let conv_fileter = model.conv_block.conv2d.filter.clone(); let conv_bias = model.conv_block.conv2d.bias.clone(); @@ -88,6 +86,7 @@ fn multi_params() { #[test] fn test_load_parameters_convnet() { + use zenu::layer::Parameters; let model = ConvNet::::new(); let parameters = model.parameters(); diff --git a/zenu-macros/tests/small_case.rs b/zenu-macros/tests/small_case.rs index fb2017db..32a9d3e2 100644 --- a/zenu-macros/tests/small_case.rs +++ b/zenu-macros/tests/small_case.rs @@ -1,7 +1,13 @@ use zenu::layer::layers::linear::Linear; +use zenu::macros::Parameters as ParametersDerive; +use zenu::matrix::{ + device::{cpu::Cpu, Device}, + matrix::Matrix, + num::Num, +}; use zenu_test::assert_val_eq; -#[derive(Parameters)] +#[derive(ParametersDerive)] #[parameters(num = T, device = D)] pub struct Hoge where @@ -13,6 +19,7 @@ where #[test] fn small_net() { + use zenu::layer::Parameters; let hoge = Hoge:: { linear: Linear::new(2, 2, true), }; @@ -50,6 +57,7 @@ fn small_net() { #[test] fn test_load_parameters() { + use zenu::layer::Parameters; let base_model = Hoge:: { linear: Linear::new(2, 2, true), }; diff --git a/zenu/src/lib.rs b/zenu/src/lib.rs index bf1fb990..42f098ea 100644 --- a/zenu/src/lib.rs +++ b/zenu/src/lib.rs @@ -15,7 +15,7 @@ use zenu_layer::Parameters; use zenu_matrix::{device::Device, num::Num}; use zenu_optimizer::Optimizer; -extern crate zenu_macros; +pub extern crate zenu_macros; pub use zenu_autograd as autograd; pub use zenu_layer as layer;