Skip to content

Commit

Permalink
Add 'a to ReqGradWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 17, 2024
1 parent 141f217 commit 61069f1
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ pub trait CachedBuffers {
}

#[inline]
fn is_supplied_from_below_module(&self) -> bool {
fn are_cached_buffers_supplied_from_below_module(&self) -> bool {
false
}
}
Expand Down
30 changes: 15 additions & 15 deletions src/modules/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<'dev, Mods> Autograd<'dev, Mods> {
S: Shape,
Mods: CachedBuffers,
{
if self.modules.is_supplied_from_below_module() {
if self.modules.are_cached_buffers_supplied_from_below_module() {
return;
}
let no_grads_pool = unsafe { &mut (*self.grads.get()).no_grads_pool };
Expand Down Expand Up @@ -226,30 +226,30 @@ impl<'dev, Mods> GradActions for Autograd<'dev, Mods> {
}

#[inline]
unsafe fn grad<
'a,
T: 'static,
D: Device + Alloc<T> + crate::ZeroGrad<T> + 'static,
S: Shape,
>(
unsafe fn grad<'a, T, D, S>(
&self,
device: &'a D,
buf: &Buffer<'a, T, D, S>,
) -> &Buffer<'a, T, D, S> {
) -> &Buffer<'a, T, D, S>
where
T: 'static,
D: Device + Alloc<T> + crate::ZeroGrad<T> + 'static,
S: Shape,
{
unsafe { (*self.grads.get()).get_ref(device, buf.id()) }
}

#[inline]
unsafe fn grad_mut<
'a,
T: 'static,
D: Device + Alloc<T> + crate::ZeroGrad<T> + 'static,
S: Shape,
>(
unsafe fn grad_mut<'a, T, D, S>(
&self,
device: &'a D,
buf: &Buffer<'a, T, D, S>,
) -> &mut Buffer<'a, T, D, S> {
) -> &mut Buffer<'a, T, D, S>
where
T: 'static,
D: Device + Alloc<T> + crate::ZeroGrad<T> + 'static,
S: Shape,
{
unsafe { (*self.grads.get()).get_mut(device, buf.id()) }
}
}
Expand Down
18 changes: 9 additions & 9 deletions src/modules/autograd/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use crate::{
};

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ReqGradWrapper<Data, T> {
pub struct ReqGradWrapper<'a, Data, T> {
pub requires_grad: bool,
pub data: Data,
pub _pd: PhantomData<T>,
pub _pd: PhantomData<&'a T>,
}

impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
type Wrap<'a, T: Unit, Base: IsBasePtr> = ReqGradWrapper<Mods::Wrap<'a, T, Base>, T>;
type Wrap<'a, T: Unit, Base: IsBasePtr> = ReqGradWrapper<'a, Mods::Wrap<'a, T, Base>, T>;

#[inline]
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
Expand All @@ -40,7 +40,7 @@ impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
}
}

impl<Data: HasId, T> HasId for ReqGradWrapper<Data, T> {
impl<'a, Data: HasId, T> HasId for ReqGradWrapper<'a, Data, T> {
#[inline]
fn id(&self) -> crate::Id {
self.data.id()
Expand All @@ -57,7 +57,7 @@ impl<Data: HasId, T> HasId for ReqGradWrapper<Data, T> {
}
}

impl<Data: PtrType, T: Unit> PtrType for ReqGradWrapper<Data, T> {
impl<'a, Data: PtrType, T: Unit> PtrType for ReqGradWrapper<'a, Data, T> {
#[inline]
fn size(&self) -> usize {
self.data.size()
Expand All @@ -74,7 +74,7 @@ impl<Data: PtrType, T: Unit> PtrType for ReqGradWrapper<Data, T> {
}
}

impl<Data, T> ShallowCopy for ReqGradWrapper<Data, T>
impl<'a, Data, T> ShallowCopy for ReqGradWrapper<'a, Data, T>
where
Data: ShallowCopy,
{
Expand All @@ -87,16 +87,16 @@ where
}
}

impl<T: Unit, S: Shape, Data: ToBase<T, D, S>, T1, D: Device> ToBase<T, D, S>
for ReqGradWrapper<Data, T1>
impl<'a, T: Unit, S: Shape, Data: ToBase<T, D, S>, T1, D: Device> ToBase<T, D, S>
for ReqGradWrapper<'a, Data, T1>
{
#[inline]
fn to_base(self) -> <D as Device>::Base<T, S> {
self.data.to_base()
}
}

impl<T, Data> ToDim for ReqGradWrapper<Data, T> {
impl<'a, T, Data> ToDim for ReqGradWrapper<'a, Data, T> {
type Out = Self;

#[inline]
Expand Down
4 changes: 2 additions & 2 deletions src/modules/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use core::{

use crate::{
AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device,
ExecNow, FastCache2, Guard, HasId, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module,
ExecNow, FastCache2, Guard, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module,
OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule,
SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData,
};
Expand Down Expand Up @@ -535,7 +535,7 @@ mod tests {
{
let device = CPU::<Cached<Base>>::new();
// let buf: Buffer<f32, _> = device.retrieve(10, ());
unsafe { Retrieve::<_, f32, ()>::retrieve_entry(&device.modules, &device, 10, &()) };
let _ = Retrieve::<_, f32, ()>::retrieve_entry(&device.modules, &device, 10, &());
};
}

Expand Down
2 changes: 1 addition & 1 deletion src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ impl<T, Mods> CachedBuffers for Lazy<'_, Mods, T> {
}

#[inline]
fn is_supplied_from_below_module(&self) -> bool {
fn are_cached_buffers_supplied_from_below_module(&self) -> bool {
true
}
}
Expand Down

0 comments on commit 61069f1

Please sign in to comment.