diff --git a/src/features.rs b/src/features.rs index 83eec0f7..37b24392 100644 --- a/src/features.rs +++ b/src/features.rs @@ -686,7 +686,7 @@ pub trait CachedBuffers { } #[inline] - fn is_supplied_from_below_module(&self) -> bool { + fn are_cached_buffers_supplied_from_below_module(&self) -> bool { false } } diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs index 96b550db..5cdf26ce 100644 --- a/src/modules/autograd.rs +++ b/src/modules/autograd.rs @@ -60,7 +60,7 @@ impl<'dev, Mods> Autograd<'dev, Mods> { S: Shape, Mods: CachedBuffers, { - if self.modules.is_supplied_from_below_module() { + if self.modules.are_cached_buffers_supplied_from_below_module() { return; } let no_grads_pool = unsafe { &mut (*self.grads.get()).no_grads_pool }; @@ -226,30 +226,30 @@ impl<'dev, Mods> GradActions for Autograd<'dev, Mods> { } #[inline] - unsafe fn grad< - 'a, - T: 'static, - D: Device + Alloc + crate::ZeroGrad + 'static, - S: Shape, - >( + unsafe fn grad<'a, T, D, S>( &self, device: &'a D, buf: &Buffer<'a, T, D, S>, - ) -> &Buffer<'a, T, D, S> { + ) -> &Buffer<'a, T, D, S> + where + T: 'static, + D: Device + Alloc + crate::ZeroGrad + 'static, + S: Shape, + { unsafe { (*self.grads.get()).get_ref(device, buf.id()) } } #[inline] - unsafe fn grad_mut< - 'a, - T: 'static, - D: Device + Alloc + crate::ZeroGrad + 'static, - S: Shape, - >( + unsafe fn grad_mut<'a, T, D, S>( &self, device: &'a D, buf: &Buffer<'a, T, D, S>, - ) -> &mut Buffer<'a, T, D, S> { + ) -> &mut Buffer<'a, T, D, S> + where + T: 'static, + D: Device + Alloc + crate::ZeroGrad + 'static, + S: Shape, + { unsafe { (*self.grads.get()).get_mut(device, buf.id()) } } } diff --git a/src/modules/autograd/wrapper.rs b/src/modules/autograd/wrapper.rs index 2fb058a4..60d943bc 100644 --- a/src/modules/autograd/wrapper.rs +++ b/src/modules/autograd/wrapper.rs @@ -6,14 +6,14 @@ use crate::{ }; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct ReqGradWrapper { +pub struct ReqGradWrapper<'a, Data, T> { pub requires_grad: bool, pub data: Data, - pub _pd: PhantomData, + pub _pd: PhantomData<&'a T>, } impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> { - type Wrap<'a, T: Unit, Base: IsBasePtr> = ReqGradWrapper, T>; + type Wrap<'a, T: Unit, Base: IsBasePtr> = ReqGradWrapper<'a, Mods::Wrap<'a, T, Base>, T>; #[inline] fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { @@ -40,7 +40,7 @@ impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> { } } -impl HasId for ReqGradWrapper { +impl<'a, Data: HasId, T> HasId for ReqGradWrapper<'a, Data, T> { #[inline] fn id(&self) -> crate::Id { self.data.id() @@ -57,7 +57,7 @@ impl HasId for ReqGradWrapper { } } -impl PtrType for ReqGradWrapper { +impl<'a, Data: PtrType, T: Unit> PtrType for ReqGradWrapper<'a, Data, T> { #[inline] fn size(&self) -> usize { self.data.size() @@ -74,7 +74,7 @@ impl PtrType for ReqGradWrapper { } } -impl ShallowCopy for ReqGradWrapper +impl<'a, Data, T> ShallowCopy for ReqGradWrapper<'a, Data, T> where Data: ShallowCopy, { @@ -87,8 +87,8 @@ where } } -impl, T1, D: Device> ToBase - for ReqGradWrapper +impl<'a, T: Unit, S: Shape, Data: ToBase, T1, D: Device> ToBase + for ReqGradWrapper<'a, Data, T1> { #[inline] fn to_base(self) -> ::Base { @@ -96,7 +96,7 @@ impl, T1, D: Device> ToBase } } -impl ToDim for ReqGradWrapper { +impl<'a, T, Data> ToDim for ReqGradWrapper<'a, Data, T> { type Out = Self; #[inline] diff --git a/src/modules/cached.rs b/src/modules/cached.rs index f1f9112b..7e335be7 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -6,7 +6,7 @@ use core::{ use crate::{ AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device, - ExecNow, FastCache2, Guard, HasId, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, + ExecNow, FastCache2, Guard, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData, }; @@ -535,7 +535,7 @@ mod tests { { let device = CPU::>::new(); // let buf: Buffer = device.retrieve(10, ()); - unsafe { Retrieve::<_, f32, ()>::retrieve_entry(&device.modules, &device, 10, &()) }; + let _ = Retrieve::<_, f32, ()>::retrieve_entry(&device.modules, &device, 10, &()); }; } diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index a77f08c4..b7a01066 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -499,7 +499,7 @@ impl CachedBuffers for Lazy<'_, Mods, T> { } #[inline] - fn is_supplied_from_below_module(&self) -> bool { + fn are_cached_buffers_supplied_from_below_module(&self) -> bool { true } }