Skip to content

Commit

Permalink
Add make_static to Guard
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 17, 2024
1 parent d05f601 commit a472d3c
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 29 deletions.
9 changes: 9 additions & 0 deletions src/cache/locking/guard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ impl<'a, T> Guard<'a, T> {
let Guard { data } = self;
Guard { data: f(data) }
}

#[inline]
pub fn make_static(self) -> Option<Guard<'static, T>> {
match self.data {
CowMutCell::Borrowed(_) => None,
CowMutCell::BorrowedMut(_) => None,
CowMutCell::Owned(data) => Some(Guard::new(CowMutCell::Owned(data))),
}
}
}

impl<'a, T> Deref for Guard<'a, T> {
Expand Down
40 changes: 29 additions & 11 deletions src/modules/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ impl<'a, Mods: Module<'a, D>, D: Device + 'a> Module<'a, D> for Autograd<'a, Mod
}

impl<'dev, Mods> Autograd<'dev, Mods> {
pub fn register_no_grad_buf<T, D, S>(&self, buf: &Buffer<T, D, S>)
pub fn register_no_grad_buf<'a, T, D, S>(&self, buf: &Buffer<T, D, S>)
where
T: Unit + 'static,
D: Device + IsShapeIndep + 'static,
D::Data<T, S>: ShallowCopy,
D::Data<'a, T, S>: ShallowCopy,
S: Shape,
{
let no_grads_pool = unsafe { &mut (*self.grads.get()).no_grads_pool };
Expand Down Expand Up @@ -121,19 +121,19 @@ impl<'dev, Mods: Setup<NewDev>, NewDev> Setup<NewDev> for Autograd<'dev, Mods> {
}
}

impl<'dev, T, Mods: Retrieve<D, T, S>, D, S: Shape> Retrieve<D, T, S> for Autograd<'dev, Mods>
impl<'dev, 'a, T, Mods: Retrieve<'a, D, T, S>, D, S: Shape> Retrieve<'a, D, T, S> for Autograd<'dev, Mods>
where
T: Unit + 'static,
D: IsShapeIndep + Device + 'static,
D::Data<T, S>: ShallowCopy,
D::Data<'a, T, S>: ShallowCopy,
{
#[inline]
unsafe fn retrieve<const NUM_PARENTS: usize>(
&self,
device: &D,
len: usize,
parents: impl Parents<NUM_PARENTS>,
) -> crate::Result<Self::Wrap<T, D::Base<T, S>>>
parents: &impl Parents<NUM_PARENTS>,
) -> crate::Result<Self::Wrap<'a, T, D::Base<T, S>>>
where
D: Alloc<T>,
{
Expand All @@ -151,15 +151,32 @@ where
_pd: core::marker::PhantomData,
})
}

#[inline]
fn on_retrieve_finish(&self, retrieved_buf: &Buffer<T, D, S>)
where
fn on_retrieve_finish<const NUM_PARENTS: usize>(
&self,
len: usize,
parents: impl Parents<NUM_PARENTS>,
retrieved_buf: &Buffer<T, D, S>,
) where
D: Alloc<T>,
{
self.register_no_grad_buf(retrieved_buf);

self.modules.on_retrieve_finish(retrieved_buf)
self.modules.on_retrieve_finish(len, parents, retrieved_buf)
}

unsafe fn retrieve_entry<const NUM_PARENTS: usize>(
&'a self,
device: &D,
len: usize,
parents: &impl Parents<NUM_PARENTS>,
) -> crate::Result<Self::Wrap<'a, T, <D>::Base<T, S>>>
where
S: Shape,
D: Alloc<T>
{
todo!()
}
}

Expand Down Expand Up @@ -293,7 +310,8 @@ mod tests {
buf_any: &'b Box<dyn BoxedShallowCopy>,
_device: &'a D,
) -> Option<&'b Buffer<'a, T, D, S>> {
buf_any.as_any().downcast_ref::<Buffer<T, D, S>>()
todo!()
// buf_any.as_any().downcast_ref::<Buffer<'a, T, D, S>>()
}

#[test]
Expand Down
10 changes: 5 additions & 5 deletions src/modules/autograd/wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use core::marker::PhantomData;

use crate::{
flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData,
flag::AllocFlag, Autograd, HasId, IsBasePtr, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData
};

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
Expand All @@ -12,11 +12,11 @@ pub struct ReqGradWrapper<Data, T> {
}

impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
type Wrap<'a, T: Unit, Base: crate::HasId + crate::PtrType> =
type Wrap<'a, T: Unit, Base: IsBasePtr> =
ReqGradWrapper<Mods::Wrap<'a, T, Base>, T>;

#[inline]
fn wrap_in_base<'a, T: Unit, Base: crate::HasId + crate::PtrType>(
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(
&self,
base: Base,
) -> Self::Wrap<'a, T, Base> {
Expand All @@ -29,14 +29,14 @@ impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
}

#[inline]
fn wrapped_as_base<'a, 'b, T: Unit, Base: crate::HasId + crate::PtrType>(
fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(
wrap: &'b Self::Wrap<'a, T, Base>,
) -> &'b Base {
Mods::wrapped_as_base(&wrap.data)
}

#[inline]
fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: crate::HasId + crate::PtrType>(
fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>(
wrap: &'b mut Self::Wrap<'a, T, Base>,
) -> &'b mut Base {
Mods::wrapped_as_base_mut(&mut wrap.data)
Expand Down
25 changes: 19 additions & 6 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ impl<'a, T, D, Mods, S, T2> OnNewBuffer<'a, T, D, S> for Lazy<'_, Mods, T2>
where
T: Unit + 'static,
D: Device + IsShapeIndep + 'static,
D::Data<T, S>: ShallowCopy,
D::Data<'a, T, S>: ShallowCopy,
Mods: OnNewBuffer<'a, T, D, S>,
S: Shape,
{
Expand Down Expand Up @@ -302,21 +302,21 @@ impl<'a, T, NewMods, SD> AddLayer<NewMods, SD> for Lazy<'a, (), T> {
}
}

impl<T, Mods, D, S, T2> Retrieve<D, T, S> for Lazy<'_, Mods, T2>
impl<'a, T, Mods, D, S, T2> Retrieve<'a, D, T, S> for Lazy<'_, Mods, T2>
where
T: Unit + 'static,
Mods: Retrieve<D, T, S>,
Mods: Retrieve<'a, D, T, S>,
D: IsShapeIndep + 'static,
D::Data<T, S>: ShallowCopy,
D::Data<'a, T, S>: ShallowCopy,
S: Shape,
{
#[inline]
unsafe fn retrieve<const NUM_PARENTS: usize>(
&self,
_device: &D,
len: usize,
_parents: impl Parents<NUM_PARENTS>,
) -> crate::Result<Self::Wrap<T, D::Base<T, S>>>
_parents: &impl Parents<NUM_PARENTS>,
) -> crate::Result<Self::Wrap<'a, T, D::Base<T, S>>>
where
S: Shape,
D: Alloc<T>,
Expand Down Expand Up @@ -378,6 +378,19 @@ where
// pass down
self.modules.on_retrieve_finish(retrieved_buf)
}

unsafe fn retrieve_entry<const NUM_PARENTS: usize>(
&'a self,
device: &D,
len: usize,
parents: &impl Parents<NUM_PARENTS>,
) -> crate::Result<Self::Wrap<'a, T, <D>::Base<T, S>>>
where
S: Shape,
D: Alloc<T>
{
todo!()
}
}

impl<T, Mods> Cursor for Lazy<'_, Mods, T> {
Expand Down
14 changes: 7 additions & 7 deletions src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use core::{
};

use crate::{
flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData,
flag::AllocFlag, HasId, HostPtr, IsBasePtr, Lazy, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData
};

#[derive(Debug, Default)]
Expand All @@ -17,25 +17,25 @@ pub struct LazyWrapper<Data, T> {
}

impl<T2, Mods: WrappedData> WrappedData for Lazy<'_, Mods, T2> {
type Wrap<T: Unit, Base: HasId + PtrType> = LazyWrapper<Mods::Wrap<T, Base>, T>;
type Wrap<'a, T: Unit, Base: IsBasePtr> = LazyWrapper<Mods::Wrap<'a, T, Base>, T>;

#[inline]
fn wrap_in_base<T: Unit, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
LazyWrapper {
maybe_data: MaybeData::Data(self.modules.wrap_in_base(base)),
_pd: PhantomData,
}
}

#[inline]
fn wrapped_as_base<T: Unit, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b Self::Wrap<'a, T, Base>) -> &'b Base {
Mods::wrapped_as_base(wrap.maybe_data.data().expect(MISSING_DATA))
}

#[inline]
fn wrapped_as_base_mut<T: Unit, Base: HasId + PtrType>(
wrap: &mut Self::Wrap<T, Base>,
) -> &mut Base {
fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>(
wrap: &'b mut Self::Wrap<'a, T, Base>,
) -> &'b mut Base {
Mods::wrapped_as_base_mut(wrap.maybe_data.data_mut().expect(MISSING_DATA))
}
}
Expand Down
21 changes: 21 additions & 0 deletions src/modules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ pub(crate) unsafe fn register_buf_copyable<'a, T, D, S>(
T: crate::Unit + 'static,
D: Device + crate::IsShapeIndep + 'static,
D::Data<'a, T, S>: ShallowCopy,
D::Base<T, S>: ShallowCopy,
S: Shape,
{
// shallow copy sets flag to AllocFlag::Wrapper
Expand All @@ -141,6 +142,26 @@ pub(crate) unsafe fn register_buf_copyable<'a, T, D, S>(
// cache.insert(*buf.id(), Box::new(buf));
}

pub(crate) unsafe fn register_buf_copyable2<'a, T, D, S>(
cache: &mut HashMap<UniqueId, Box<dyn crate::BoxedShallowCopy>, impl BuildHasher>,
buf: &Buffer<'a, T, D, S>,
) where
T: crate::Unit + 'static,
D: Device + crate::IsShapeIndep + 'static,
D::Base<T, S>: ShallowCopy,
S: Shape,
{
// shallow copy sets flag to AllocFlag::Wrapper
let wrapped_data = unsafe { buf.base().shallow() };

// let buf: Buffer<T, D, S> = Buffer {
// data: wrapped_data,
// device: None,
// };
todo!()
// cache.insert(*buf.id(), Box::new(buf));
}

#[cfg(feature = "std")]
#[inline]
#[allow(unused)]
Expand Down

0 comments on commit a472d3c

Please sign in to comment.