Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: FL03 <jo3mccain@icloud.com>
  • Loading branch information
FL03 committed Feb 20, 2024
1 parent 4ba7965 commit 9e6730d
Show file tree
Hide file tree
Showing 13 changed files with 34 additions and 41 deletions.
3 changes: 0 additions & 3 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ version.workspace = true
[features]
default = []



[lib]
bench = false
crate-type = ["cdylib", "rlib"]
Expand All @@ -26,7 +24,6 @@ test = true

[dependencies]
anyhow.workspace = true
# daggy = { features = ["serde-1"], version = "0.8" }
lazy_static = "1"
num = "0.4"
petgraph = { features = ["serde-1"], version = "0.6" }
Expand Down
15 changes: 15 additions & 0 deletions core/src/ops/gradient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,18 @@ pub trait Gradient<T> {
fn grad(&self, args: T) -> Self::Gradient;
}

/// Trait for types that can compute a gradient at a point with respect to a
/// single variable identified by name.
///
/// NOTE(review): the `wrt` key is a plain `&str`; the TODO below suggests the
/// authors intend to replace these stringly-typed keys with generated ones —
/// see the sibling `Parameter` trait, which models a key/value pair.
pub trait Grad<T> {
    /// The type produced by the gradient computation.
    type Output;

    /// Compute the gradient of a function at a given point, with respect to a given variable.
    // TODO: Create a macro for generating parameter keys
    fn grad(&self, at: T, wrt: &str) -> Self::Output;
}

/// A named parameter: pairs an identifying key with an associated value.
///
/// NOTE(review): both accessors return owned `Self::Key` / `Self::Value`
/// (not references), so implementors with non-`Copy` associated types will
/// clone on each call — confirm this is intended before implementing widely.
pub trait Parameter {
    /// Identifier type used to reference this parameter.
    type Key;
    /// The type of the value this parameter carries.
    type Value;

    /// Returns the parameter's key.
    fn key(&self) -> Self::Key;
    /// Returns the parameter's value.
    fn value(&self) -> Self::Value;
}
1 change: 0 additions & 1 deletion macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ test = true
num = "0.4"
proc-macro2 = { features = ["nightly", "span-locations"], version = "1" }
quote = "1"

syn = { features = ["extra-traits", "fold", "full"], version = "2" }

[dev-dependencies]
Expand Down
15 changes: 0 additions & 15 deletions macros/examples/sample.rs

This file was deleted.

2 changes: 1 addition & 1 deletion macros/src/ast/gradient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
Appellation: gradient <module>
Contrib: FL03 <jo3mccain@icloud.com>
*/
use syn::{Attribute, ItemFn};
use syn::parse::{Parse, ParseStream, Result};
use syn::{Attribute, ItemFn};

pub struct GradientAst {
pub attrs: Vec<Attribute>,
Expand Down
4 changes: 2 additions & 2 deletions macros/src/diff/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

pub mod handle;

use crate::ast::partials::{PartialAst, PartialFn};
use handle::expr::handle_expr;
use handle::item::handle_item;
use crate::ast::partials::{PartialAst, PartialFn};
use proc_macro2::TokenStream;
use syn::Ident;

Expand All @@ -24,4 +24,4 @@ fn handle_input(input: &PartialFn, var: &Ident) -> TokenStream {
PartialFn::Expr(inner) => handle_expr(&inner, var),
PartialFn::Item(inner) => handle_item(&inner.clone().into(), var),
}
}
}
18 changes: 8 additions & 10 deletions macros/src/grad/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,17 @@
Contrib: FL03 <jo3mccain@icloud.com>
*/


use crate::ast::gradient::GradientAst;
use crate::diff::handle::block::handle_block;
use quote::quote;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{ItemFn, Signature};

pub fn gradient(grad: &GradientAst) -> TokenStream {
let GradientAst { attrs, item } = grad;
let attrs = attrs;
let _attrs = attrs;
let item = item;
let output = quote! {
#(#attrs)*
#item
};
output
handle_item_fn(&item)
}

fn handle_item_fn(item: &ItemFn) -> TokenStream {
Expand All @@ -34,9 +29,12 @@ fn handle_item_fn(item: &ItemFn) -> TokenStream {
}
}

let grad = vars.iter().map(|var| handle_block(&block, &var)).collect::<Vec<_>>();
let grad = vars
.iter()
.map(|var| handle_block(&block, &var))
.collect::<Vec<_>>();

quote! {
[#(#grad)*]
}
}
}
3 changes: 1 addition & 2 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#![feature(proc_macro_span)]
extern crate proc_macro;


pub(crate) mod ast;
pub(crate) mod cmp;
pub(crate) mod diff;
Expand Down Expand Up @@ -64,7 +63,7 @@ pub fn autodiff(input: TokenStream) -> TokenStream {
}

#[proc_macro]
pub fn grad(input: TokenStream) -> TokenStream {
pub fn gradient(input: TokenStream) -> TokenStream {
// Parse the input expression into a syntax tree
let expr = parse_macro_input!(input as Expr);

Expand Down
4 changes: 2 additions & 2 deletions macros/tests/gradient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn test_grad_addition() {
[(y, 1.0)]
);
let z = 3.0;
let df = grad!(x + y + z);
let df = gradient!(x + y + z);
assert_eq!(
df.into_iter().filter(|(k, _v)| k == &x).collect::<Vec<_>>(),
[(x, 1.0)]
Expand Down Expand Up @@ -50,7 +50,7 @@ fn test_grad_multiply() {
df.into_iter().filter(|(k, _v)| k == &y).collect::<Vec<_>>(),
[(y, 1.0)]
);
let df = grad!(x * y + 3.0);
let df = gradient!(x * y + 3.0);
assert_eq!(
df.into_iter().filter(|(k, _v)| k == &x).collect::<Vec<_>>(),
[(x, 2.0)]
Expand Down
4 changes: 2 additions & 2 deletions tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
//!
//!
#![feature(array_chunks)]

extern crate acme_core as acme;
pub use self::{specs::*, tensor::*};

pub(crate) mod specs;
Expand All @@ -16,8 +18,6 @@ pub mod ops;
pub mod shape;
pub mod store;

pub(crate) use acme_core as core;

pub mod prelude {
pub use crate::specs::*;

Expand Down
2 changes: 1 addition & 1 deletion tensor/src/ops/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
Appellation: backprop <mod>
Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::core::prelude::Ops;
use acme::prelude::Ops;

/// Newtype wrapping an optional `Ops` value for backpropagation.
/// `None` presumably means no operation is associated (e.g. a leaf) —
/// TODO confirm against the callers that construct this type.
pub struct BackpropOp(Option<Ops>);
2 changes: 1 addition & 1 deletion tensor/src/specs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Appellation: specs <mod>
Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::core::cmp::id::AtomicId;
use crate::shape::{Rank, Shape};
use crate::store::Layout;
use acme::cmp::id::AtomicId;

pub trait Affine<T> {
type Output;
Expand Down
2 changes: 1 addition & 1 deletion tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
Appellation: tensor <mod>
Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::core::cmp::id::AtomicId;
use crate::data::Scalar;
use crate::shape::{IntoShape, Rank, Shape};
use crate::store::Layout;
use acme::cmp::id::AtomicId;
// use std::ops::{Index, IndexMut};
// use std::sync::{Arc, RwLock};

Expand Down

0 comments on commit 9e6730d

Please sign in to comment.