From 9ebd95789d8640cf63c59dc42b29762e1cc39452 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Fri, 1 Nov 2024 17:15:11 +0100 Subject: [PATCH] Update LockedMap --- examples/custom_module.rs | 4 +- src/buffer/num.rs | 16 +++-- src/cache/locking.rs | 4 +- src/cache/locking/guard.rs | 10 +-- src/cache/locking/locked_array.rs | 11 ++- src/cache/locking/locked_map.rs | 21 ++++-- src/cache/owned_cache.rs | 4 +- src/cache/owned_cache/fast_cache.rs | 40 ++++++----- src/cache/owned_cache/length_cache.rs | 4 +- src/cow_mut.rs | 2 +- src/devices.rs | 5 +- src/devices/cpu/ops.rs | 7 +- src/devices/cuda/ops.rs | 3 +- src/devices/opencl/ops.rs | 10 ++- src/devices/stack/stack_device.rs | 3 +- src/devices/stack_array.rs | 1 - src/devices/untyped/ops.rs | 3 +- src/devices/vulkan/ops.rs | 3 +- src/devices/wgsl/ops.rs | 3 +- src/features.rs | 14 ++-- src/lib.rs | 4 +- src/modules/autograd/wrapper.rs | 7 +- src/modules/base.rs | 17 +++-- src/modules/cached.rs | 99 ++++++++++++++++++++------- src/modules/fork.rs | 4 +- src/modules/graph.rs | 4 +- src/modules/lazy/wrapper.rs | 6 +- src/wrapper.rs | 10 ++- 28 files changed, 215 insertions(+), 104 deletions(-) diff --git a/examples/custom_module.rs b/examples/custom_module.rs index 694abce0..91117f86 100644 --- a/examples/custom_module.rs +++ b/examples/custom_module.rs @@ -49,7 +49,9 @@ impl WrappedData for CustomModule { } #[inline] - fn wrapped_as_base_mut(wrap: &mut Self::Wrap) -> &mut Base { + fn wrapped_as_base_mut( + wrap: &mut Self::Wrap, + ) -> &mut Base { Mods::wrapped_as_base_mut(wrap) } } diff --git a/src/buffer/num.rs b/src/buffer/num.rs index 16809f77..f2b5ab2f 100644 --- a/src/buffer/num.rs +++ b/src/buffer/num.rs @@ -4,7 +4,8 @@ use core::{ }; use crate::{ - flag::AllocFlag, Alloc, Buffer, CloneBuf, Device, HasId, IsBasePtr, OnDropBuffer, PtrType, ShallowCopy, Unit, WrappedData + flag::AllocFlag, Alloc, Buffer, CloneBuf, Device, HasId, IsBasePtr, OnDropBuffer, PtrType, + ShallowCopy, Unit, WrappedData, }; #[derive(Debug, Default)] @@ -60,7 +61,10 @@ impl Device for () { } #[inline(always)] - fn base_to_data<'a, T: Unit, S: crate::Shape>(&self, base: Self::Base) -> Self::Data<'a, T, S> { + fn base_to_data<'a, T: Unit, S: crate::Shape>( + &self, + base: Self::Base, + ) -> Self::Data<'a, T, S> { base } @@ -114,12 +118,16 @@ impl WrappedData for () { } #[inline] - fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b Self::Wrap<'a, T, Base>) -> &'b Base { + fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>( + wrap: &'b Self::Wrap<'a, T, Base>, + ) -> &'b Base { wrap } #[inline] - fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b mut Self::Wrap<'a, T, Base>) -> &'b mut Base { + fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>( + wrap: &'b mut Self::Wrap<'a, T, Base>, + ) -> &'b mut Base { wrap } } diff --git a/src/cache/locking.rs b/src/cache/locking.rs index 533bd2ca..e8c85b20 100644 --- a/src/cache/locking.rs +++ b/src/cache/locking.rs @@ -1,6 +1,6 @@ +mod guard; mod locked_array; mod locked_map; -mod guard; pub use guard::*; pub use locked_array::*; pub use locked_map::*; @@ -11,4 +11,4 @@ pub type State = Result; pub enum LockInfo { Locked, None, -} \ No newline at end of file +} diff --git a/src/cache/locking/guard.rs b/src/cache/locking/guard.rs index 0c878912..ae0a3bd7 100644 --- a/src/cache/locking/guard.rs +++ b/src/cache/locking/guard.rs @@ -1,4 +1,7 @@ -use core::{mem::ManuallyDrop, ops::{Deref, DerefMut}}; +use core::{ + mem::ManuallyDrop, + ops::{Deref, DerefMut}, +}; use crate::{CowMutCell, HasId, PtrType, ShallowCopy}; @@ -18,9 +21,7 @@ impl<'a, T> Guard<'a, T> { F: FnOnce(CowMutCell<'a, T>) -> CowMutCell<'a, U>, { let Guard { data } = self; - Guard { - data: f(data), - } + Guard { data: f(data) } } } @@ -69,4 +70,3 @@ impl<'a, T> ShallowCopy for Guard<'a, T> { todo!() } } - diff --git a/src/cache/locking/locked_array.rs b/src/cache/locking/locked_array.rs index 579bb11b..fdb3c144 100644 --- a/src/cache/locking/locked_array.rs +++ b/src/cache/locking/locked_array.rs @@ -37,10 +37,9 @@ impl LockedArray { if data.is_none() { return State::Err(LockInfo::None); } - return State::Ok(Guard::new(CowMutCell::Borrowed(Ref::map( - data, - |data| data.as_ref().unwrap(), - )))); + return State::Ok(Guard::new(CowMutCell::Borrowed(Ref::map(data, |data| { + data.as_ref().unwrap() + })))); } Err(_) => return State::Err(LockInfo::Locked), } @@ -84,7 +83,7 @@ mod tests { data1.push(2); assert_eq!(data1.as_slice(), [1, 2]); } - + #[cfg(feature = "std")] #[test] #[should_panic] @@ -93,7 +92,7 @@ mod tests { locked_array.set(1, vec![10]); locked_array.set(1, vec![10]); } - + #[cfg(feature = "std")] #[test] fn test_get_not_set() { diff --git a/src/cache/locking/locked_map.rs b/src/cache/locking/locked_map.rs index 2e75a715..395e80ab 100644 --- a/src/cache/locking/locked_map.rs +++ b/src/cache/locking/locked_map.rs @@ -1,3 +1,4 @@ +use core::ops::Deref; use std::{ cell::{Ref, RefCell, RefMut, UnsafeCell}, collections::HashMap, @@ -7,7 +8,7 @@ use std::{ use crate::{LockInfo, State}; pub struct LockedMap { - data: UnsafeCell>, S>>, + data: RefCell>, S>>, } impl Default for LockedMap { @@ -27,11 +28,22 @@ impl LockedMap { } } impl LockedMap { + #[inline] + pub fn len(&self) -> usize { + self.data.borrow().len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.data.borrow().is_empty() + } + pub fn insert(&self, id: K, data: T) where K: Eq + Hash, { - let map = unsafe { &mut *self.data.get() }; + // let map = unsafe { &mut *self.data.get() }; + let mut map = self.data.borrow_mut(); if map.contains_key(&id) { panic!() } @@ -42,7 +54,7 @@ impl LockedMap { where K: Eq + Hash, { - let map = unsafe { &mut *self.data.get() }; + let map = unsafe { &*self.data.as_ptr() }; let entry = map.get(id).ok_or(LockInfo::None)?; (&**entry).try_borrow().map_err(|_| LockInfo::Locked) } @@ -51,7 +63,7 @@ impl LockedMap { where K: Eq + Hash, { - let map = unsafe { &mut *self.data.get() }; + let map = unsafe { &*self.data.as_ptr() }; let entry = map.get(id).ok_or(LockInfo::None)?; (&**entry).try_borrow_mut().map_err(|_| LockInfo::Locked) } @@ -67,7 +79,6 @@ mod tests { #[test] fn test_locked_boxed() { - let locked_map = LockedMap::, BuildHasherDefault>::new(); locked_map.insert(0, vec![1, 2, 3, 4]); diff --git a/src/cache/owned_cache.rs b/src/cache/owned_cache.rs index 7a8e6649..ead8c03e 100644 --- a/src/cache/owned_cache.rs +++ b/src/cache/owned_cache.rs @@ -16,7 +16,7 @@ pub trait Cache { ) -> D::Base where T: Unit, - D: Alloc + 'static, - D::Base: ShallowCopy + 'static, + D: Alloc, + D::Base: ShallowCopy, S: Shape; } diff --git a/src/cache/owned_cache/fast_cache.rs b/src/cache/owned_cache/fast_cache.rs index d29c9e3d..2bfe555b 100644 --- a/src/cache/owned_cache/fast_cache.rs +++ b/src/cache/owned_cache/fast_cache.rs @@ -2,12 +2,11 @@ use core::{any::Any, hash::BuildHasherDefault}; use std::{collections::HashMap, sync::Arc}; use crate::{ - flag::AllocFlag, Alloc, Cache, Device, NoHasher, PtrType, ShallowCopy, Shape, UniqueId, Unit, + flag::AllocFlag, Alloc, Cache, Device, LockedMap, NoHasher, PtrType, ShallowCopy, Shape, UniqueId, Unit }; -#[derive(Clone)] pub struct FastCache { - pub nodes: HashMap, BuildHasherDefault>, + pub nodes: LockedMap, BuildHasherDefault>, } impl Default for FastCache { @@ -28,11 +27,11 @@ impl Cache for FastCache { ) -> D::Base where T: Unit, - D: Alloc + 'static, + D: Alloc, D::Base: ShallowCopy + 'static, S: Shape, { - self.get(device, id, len, new_buf_callback) + self.get(device, id, len, new_buf_callback).unwrap() } } @@ -53,16 +52,16 @@ impl FastCache { id: UniqueId, len: usize, new_buf_callback: impl FnMut(UniqueId, &D::Base), - ) -> D::Base + ) -> crate::Result> where T: Unit, - D: Alloc + 'static, + D: Alloc, D::Base: ShallowCopy + 'static, S: Shape, { let maybe_allocated = self.nodes.get(&id); match maybe_allocated { - Some(data) => { + Ok(data) => { let data = unsafe { data.downcast_ref::>() .expect("Invalid request for data type!") @@ -71,32 +70,39 @@ impl FastCache { // TODO: not necessary, could add length to hashmap assert_eq!(data.size(), len, "Data size mismatch! Did you use e.g. if conditions in a (cursor) loop retrieving buffers with a different size?"); - data + Ok(data) } - None => unsafe { self.add_node(device, id, len, new_buf_callback) }, + Err(e) => { + match e { + crate::LockInfo::Locked => panic!("should return error"), + crate::LockInfo::None => { + unsafe { self.add_node(device, id, len, new_buf_callback) } + }, + } + }, } } unsafe fn add_node( - &mut self, + &self, device: &D, id: UniqueId, len: usize, mut callback: impl FnMut(UniqueId, &D::Base), - ) -> ::Base + ) -> crate::Result<::Base> where T: Unit, D: Alloc, D::Base: ShallowCopy + 'static, S: Shape, { - let data = device.alloc::(len, AllocFlag::None).unwrap(); + let data = device.alloc::(len, AllocFlag::None)?; let shallow_data = unsafe { data.shallow() }; callback(id, &shallow_data); self.nodes.insert(id, Arc::new(data)); - shallow_data + Ok(shallow_data) } } @@ -115,12 +121,12 @@ mod tests { assert_eq!(cache.nodes.len(), 0); - let out = unsafe { cache.add_node::(&device, 0, 10, |_a, _b| ()) }; + let out = unsafe { cache.add_node::(&device, 0, 10, |_a, _b| ()) }.unwrap(); assert_eq!(cache.nodes.len(), 1); assert_eq!(out.len, 10); - let out1 = unsafe { cache.get::(&device, 1, 10, |_a, _b| ()) }; + let out1 = unsafe { cache.get::(&device, 1, 10, |_a, _b| ()) }.unwrap(); assert_ne!(out.ptr, out1.ptr); } @@ -135,7 +141,7 @@ mod tests { let mut prev = None; for _ in device.range(0..1000) { - let out3 = unsafe { cache.get::(&device, 0, 10, |_a, _b| ()) }; + let out3 = unsafe { cache.get::(&device, 0, 10, |_a, _b| ()) }.unwrap(); if prev.is_none() { prev = Some(out3.ptr); } diff --git a/src/cache/owned_cache/length_cache.rs b/src/cache/owned_cache/length_cache.rs index 682c6d64..01c66814 100644 --- a/src/cache/owned_cache/length_cache.rs +++ b/src/cache/owned_cache/length_cache.rs @@ -26,7 +26,7 @@ impl Cache for LengthCache { ) -> D::Base where T: Unit, - D: Alloc + 'static, + D: Alloc, D::Base: ShallowCopy + 'static, S: Shape, { @@ -54,7 +54,7 @@ impl LengthCache { ) -> D::Base where T: Unit, - D: Alloc + 'static, + D: Alloc, D::Base: ShallowCopy + 'static, S: Shape, { diff --git a/src/cow_mut.rs b/src/cow_mut.rs index 36ec3308..83b34a63 100644 --- a/src/cow_mut.rs +++ b/src/cow_mut.rs @@ -1,5 +1,5 @@ -use core::ops::{Deref, DerefMut}; use core::cell::{Ref, RefMut}; +use core::ops::{Deref, DerefMut}; pub type CowMutCell<'a, T> = CowMut, Ref<'a, T>>; pub type CowMutRef<'a, T> = CowMut; diff --git a/src/devices.rs b/src/devices.rs index 3d5f3fd5..04083c9e 100644 --- a/src/devices.rs +++ b/src/devices.rs @@ -63,8 +63,9 @@ pub trait Device: OnDropBuffer + Sized { &self, wrap: Self::Wrap<'a, T, Self::Base>, ) -> Self::Data<'a, T, S>; - fn data_as_wrap<'a, 'b, T: Unit, S: Shape>(data: &'b Self::Data<'a, T, S>) - -> &'b Self::Wrap<'a, T, Self::Base>; + fn data_as_wrap<'a, 'b, T: Unit, S: Shape>( + data: &'b Self::Data<'a, T, S>, + ) -> &'b Self::Wrap<'a, T, Self::Base>; fn data_as_wrap_mut<'a, 'b, T: Unit, S: Shape>( data: &'b mut Self::Data<'a, T, S>, ) -> &'b mut Self::Wrap<'a, T, Self::Base>; diff --git a/src/devices/cpu/ops.rs b/src/devices/cpu/ops.rs index b0ff08c4..20337c25 100644 --- a/src/devices/cpu/ops.rs +++ b/src/devices/cpu/ops.rs @@ -76,7 +76,12 @@ where D::Base: Deref, S: Shape, { - type Read<'a> = &'a [T] where T: 'a, D: 'a, S: 'a; + type Read<'a> + = &'a [T] + where + T: 'a, + D: 'a, + S: 'a; #[inline] fn read<'a>(&self, buf: &'a D::Base) -> Self::Read<'a> diff --git a/src/devices/cuda/ops.rs b/src/devices/cuda/ops.rs index bf6431f8..e545b81d 100644 --- a/src/devices/cuda/ops.rs +++ b/src/devices/cuda/ops.rs @@ -18,7 +18,8 @@ pass_down_add_operation!(CUDA); pass_down_exec_now!(CUDA); impl Read for CUDA { - type Read<'a> = Vec + type Read<'a> + = Vec where T: 'a, CUDA: 'a; diff --git a/src/devices/opencl/ops.rs b/src/devices/opencl/ops.rs index 3e374159..5e0d203d 100644 --- a/src/devices/opencl/ops.rs +++ b/src/devices/opencl/ops.rs @@ -175,9 +175,15 @@ where S: Shape, { #[cfg(not(unified_cl))] - type Read<'a> = Vec where T: 'a; + type Read<'a> + = Vec + where + T: 'a; #[cfg(unified_cl)] - type Read<'a> = &'a [T] where T: 'a; + type Read<'a> + = &'a [T] + where + T: 'a; #[cfg(not(unified_cl))] #[inline] diff --git a/src/devices/stack/stack_device.rs b/src/devices/stack/stack_device.rs index 2b3a5bd4..7a1c10cc 100644 --- a/src/devices/stack/stack_device.rs +++ b/src/devices/stack/stack_device.rs @@ -102,7 +102,8 @@ impl Read for Stack where S::ARR: Copy, { - type Read<'a> = S::ARR + type Read<'a> + = S::ARR where T: 'a, Stack: 'a, diff --git a/src/devices/stack_array.rs b/src/devices/stack_array.rs index 4a4a6597..72274aa7 100644 --- a/src/devices/stack_array.rs +++ b/src/devices/stack_array.rs @@ -137,7 +137,6 @@ impl HostPtr for StackArray { } } - impl WrappedCopy for StackArray { type Base = Self; diff --git a/src/devices/untyped/ops.rs b/src/devices/untyped/ops.rs index 1e6d25a2..a1f8b1fb 100644 --- a/src/devices/untyped/ops.rs +++ b/src/devices/untyped/ops.rs @@ -6,7 +6,8 @@ use crate::{ use super::{untyped_device::Untyped, AsType}; impl Read for Untyped { - type Read<'a> = Vec + type Read<'a> + = Vec where T: 'a, Self: 'a, diff --git a/src/devices/vulkan/ops.rs b/src/devices/vulkan/ops.rs index bd6e6a18..6a825e25 100644 --- a/src/devices/vulkan/ops.rs +++ b/src/devices/vulkan/ops.rs @@ -60,7 +60,8 @@ pub fn try_vk_clear( } impl Read for Vulkan { - type Read<'a> = VkArray + type Read<'a> + = VkArray where T: 'a, Self: 'a, diff --git a/src/devices/wgsl/ops.rs b/src/devices/wgsl/ops.rs index c8da72d4..591d6c91 100644 --- a/src/devices/wgsl/ops.rs +++ b/src/devices/wgsl/ops.rs @@ -6,7 +6,8 @@ use crate::{ use super::{wgsl_device::Wgsl, AsShaderArg, WgslShaderLaunch}; impl, Mods: OnDropBuffer + 'static> Read for Wgsl { - type Read<'a> = D::Read<'a> + type Read<'a> + = D::Read<'a> where T: 'a, D: 'a, diff --git a/src/features.rs b/src/features.rs index 867977c3..7213cd8c 100644 --- a/src/features.rs +++ b/src/features.rs @@ -35,7 +35,7 @@ pub trait Retrieve<'a, D, T: Unit, S: Shape = ()>: OnDropBuffer { ) -> crate::Result>> where S: Shape, - D: Device + Alloc; + D: Alloc; #[track_caller] unsafe fn retrieve( @@ -46,14 +46,16 @@ pub trait Retrieve<'a, D, T: Unit, S: Shape = ()>: OnDropBuffer { ) -> crate::Result>> where S: Shape, - D: Device + Alloc; - - + D: Alloc; // "actor" #[inline] - fn on_retrieve_finish(&self, _len: usize, _parents: impl Parents, _retrieved_buf: &Buffer) - where + fn on_retrieve_finish( + &self, + _len: usize, + _parents: impl Parents, + _retrieved_buf: &Buffer, + ) where D: Alloc, { } diff --git a/src/lib.rs b/src/lib.rs index 617e1b83..c0a88966 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,6 +101,7 @@ pub mod flag; mod any_op; #[cfg(feature = "std")] mod boxed_shallow_copy; +mod cow_mut; pub mod hooks; mod id; mod layer_management; @@ -113,10 +114,10 @@ mod shape; mod two_way_ops; mod unary; mod wrapper; -mod cow_mut; pub use any_op::*; pub use cache::*; +pub use cow_mut::*; pub use features::*; pub use hooks::*; pub use id::*; @@ -126,7 +127,6 @@ pub use number::*; pub use parents::*; pub use range::*; pub use wrapper::*; -pub use cow_mut::*; #[cfg(not(feature = "cpu"))] pub mod dummy_cpu; diff --git a/src/modules/autograd/wrapper.rs b/src/modules/autograd/wrapper.rs index 522dcad2..b17a02f5 100644 --- a/src/modules/autograd/wrapper.rs +++ b/src/modules/autograd/wrapper.rs @@ -1,6 +1,8 @@ use core::marker::PhantomData; -use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData}; +use crate::{ + flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData, +}; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct ReqGradWrapper { @@ -10,7 +12,8 @@ pub struct ReqGradWrapper { } impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> { - type Wrap<'a, T: Unit, Base: crate::HasId + crate::PtrType> = ReqGradWrapper, T>; + type Wrap<'a, T: Unit, Base: crate::HasId + crate::PtrType> = + ReqGradWrapper, T>; #[inline] fn wrap_in_base<'a, T: Unit, Base: crate::HasId + crate::PtrType>( diff --git a/src/modules/base.rs b/src/modules/base.rs index fa9d5df4..fbe8d52a 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -11,17 +11,24 @@ impl WrappedData for Base { type Wrap<'a, T: Unit, Base: 'static + HasId + PtrType> = Base; #[inline] - fn wrap_in_base<'a, T: Unit, Base: 'static + HasId + PtrType>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + fn wrap_in_base<'a, T: Unit, Base: 'static + HasId + PtrType>( + &self, + base: Base, + ) -> Self::Wrap<'a, T, Base> { base } #[inline] - fn wrapped_as_base<'a, 'b, T: Unit, Base: 'static + HasId + PtrType>(wrap: &'b Self::Wrap<'a, T, Base>) -> &'b Base { + fn wrapped_as_base<'a, 'b, T: Unit, Base: 'static + HasId + PtrType>( + wrap: &'b Self::Wrap<'a, T, Base>, + ) -> &'b Base { wrap } #[inline] - fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: 'static + HasId + PtrType>(wrap: &'b mut Self::Wrap<'a, T, Base>) -> &'b mut Base { + fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: 'static + HasId + PtrType>( + wrap: &'b mut Self::Wrap<'a, T, Base>, + ) -> &'b mut Base { wrap } } @@ -92,7 +99,7 @@ impl<'a, D, T: Unit, S: Shape> Retrieve<'a, D, T, S> for Base { self.retrieve(device, len, parents) } - #[inline] + #[inline] unsafe fn retrieve( &self, device: &D, @@ -101,7 +108,7 @@ impl<'a, D, T: Unit, S: Shape> Retrieve<'a, D, T, S> for Base { ) -> crate::Result::Base>> where S: Shape, - D: Device + Alloc + D: Device + Alloc, { device.alloc(len, AllocFlag::None) } diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 3119a8d0..7cb94095 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -1,10 +1,13 @@ use core::{ - cell::{Cell, RefCell}, + cell::{Cell, RefCell, RefMut}, marker::PhantomData, }; use crate::{ - AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device, ExecNow, FastCache, Guard, HasId, HasModules, IsBasePtr, IsShapeIndep, Module, OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, SetOpHint, Setup, ShallowCopy, Shape, UniqueId, Unit, WrappedData + AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device, + ExecNow, FastCache, Guard, HasId, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, + OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, + SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData, }; #[cfg(feature = "graph")] @@ -19,21 +22,24 @@ pub struct Cached { } impl WrappedData for CachedModule { - type Wrap<'a, T: Unit, Base: IsBasePtr> = Mods::Wrap<'a, T, Base>; + type Wrap<'a, T: Unit, Base: IsBasePtr> = Guard<'a, Mods::Wrap<'static, T, Base>>; #[inline] fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { - // Guard::new(CowMut::Owned(self.modules.wrap_in_base(base))) - self.modules.wrap_in_base(base) + Guard::new(CowMut::Owned(self.modules.wrap_in_base(base))) } #[inline] - fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b Self::Wrap<'a, T, Base>) -> &'b Base { + fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>( + wrap: &'b Self::Wrap<'a, T, Base>, + ) -> &'b Base { Mods::wrapped_as_base(wrap) } #[inline] - fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b mut Self::Wrap<'a, T, Base>) -> &'b mut Base { + fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>( + wrap: &'b mut Self::Wrap<'a, T, Base>, + ) -> &'b mut Base { Mods::wrapped_as_base_mut(wrap) } } @@ -48,6 +54,7 @@ where CachedModule { modules: Mods::new(), cache: RefCell::new(CacheType::default()), + cache3: Default::default(), pd: PhantomData, cursor: Default::default(), } @@ -60,6 +67,7 @@ where pub struct CachedModule { pub modules: Mods, pub cache: RefCell, + pub cache3: crate::LockedMap>, pub(crate) pd: PhantomData, cursor: Cell, // would move this to `Cache`, however -> RefCell; TODO: maybe add a Cursor Module } @@ -144,21 +152,43 @@ impl OnDropBuffer for CachedModule CachedModule +where + Mods: WrappedData, + SimpleDevice: Device, +{ + pub fn get( + &'a self, + id: u64, + ) -> State>>> + where + D: Device, + T: 'static, + S: Shape, + { + let entry = self.cache3.get_mut(&id)?; + let entry = RefMut::map(entry, |x| { + x.downcast_mut::>>() + .unwrap() + }); + Ok(Guard::new(CowMut::BorrowedMut(entry))) + } +} + // TODO: a more general OnDropBuffer => "Module" impl<'a, CacheType, T, Mods, D, SimpleDevice, S: Shape> Retrieve<'a, D, T, S> for CachedModule where T: Unit + 'static, - Mods: Retrieve<'a, D, T, S>, - D: Device + IsShapeIndep + Cursor + 'static, - D::Base: ShallowCopy + 'static, - D::Data<'a, T, S>: ShallowCopy + 'static, + Mods: Retrieve<'static, D, T, S>, + D: Device + IsShapeIndep + Cursor, + D::Base: 'static, SimpleDevice: Device, CacheType: Cache, { #[inline] unsafe fn retrieve_entry( - &self, + &'a self, device: &D, len: usize, _parents: &impl Parents, @@ -166,24 +196,40 @@ where where D: Alloc, { - let retrieved = Ok(self.wrap_in_base(self.cache.borrow_mut().get( - device, - device.cursor() as UniqueId, - len, - |_cursor, _base| {}, - ))); - unsafe { device.bump_cursor() }; - retrieved + let id = device.cursor() as UniqueId; + match self.get::(id) { + Ok(out) => Ok(out), + Err(state) => match state { + LockInfo::Locked => panic!("Locked!!"), + LockInfo::None => { + self.cache3 + .insert(id, Box::new(self.modules.retrieve(device, len, _parents))); + Ok(self.get::(id).unwrap()) + } + }, + } + // let retrieved = Ok(self.wrap_in_base(self.cache.borrow_mut().get( + // device, + // device.cursor() as UniqueId, + // len, + // |_cursor, _base| {}, + // ))); + // unsafe { device.bump_cursor() }; + // retrieved } #[inline] - fn on_retrieve_finish(&self, len: usize, parents: impl Parents, retrieved_buf: &Buffer) - where + fn on_retrieve_finish( + &self, + len: usize, + parents: impl Parents, + retrieved_buf: &Buffer, + ) where D: Alloc, { self.modules.on_retrieve_finish(len, parents, retrieved_buf) } - + unsafe fn retrieve( &self, _device: &D, @@ -192,7 +238,7 @@ where ) -> crate::Result::Base>> where S: Shape, - D: Device + Alloc + D: Device + Alloc, { panic!("Modules retrieve calls are in the wrong order. Cached module requires to be called via 'retrieve_entry'") } @@ -280,6 +326,7 @@ impl AddLayer for Cached<() crate::CachedModule { modules: inner_mods, cache: Default::default(), + cache3: Default::default(), pd: core::marker::PhantomData, cursor: Default::default(), } @@ -475,10 +522,10 @@ mod tests { #[test] fn test_cached_return_retrieve() { // invalid! - let _x = { + { let device = CPU::>::new(); // let buf: Buffer = device.retrieve(10, ()); - unsafe { Retrieve::<_, f32, ()>::retrieve_entry(&device.modules, &device, 10, &()) } + unsafe { Retrieve::<_, f32, ()>::retrieve_entry(&device.modules, &device, 10, &()) }; }; } diff --git a/src/modules/fork.rs b/src/modules/fork.rs index 4a257451..f42504ce 100644 --- a/src/modules/fork.rs +++ b/src/modules/fork.rs @@ -41,7 +41,9 @@ impl WrappedData for Fork { } #[inline] - fn wrapped_as_base_mut(wrap: &mut Self::Wrap) -> &mut Base { + fn wrapped_as_base_mut( + wrap: &mut Self::Wrap, + ) -> &mut Base { Mods::wrapped_as_base_mut(wrap) } } diff --git a/src/modules/graph.rs b/src/modules/graph.rs index 909fc9b0..6f6f0c44 100644 --- a/src/modules/graph.rs +++ b/src/modules/graph.rs @@ -39,7 +39,9 @@ impl WrappedData for Graph { } #[inline] - fn wrapped_as_base_mut(wrap: &mut Self::Wrap) -> &mut Base { + fn wrapped_as_base_mut( + wrap: &mut Self::Wrap, + ) -> &mut Base { Mods::wrapped_as_base_mut(wrap) } } diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index 5ee8048f..09cd121c 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -7,7 +7,7 @@ use core::{ }; use crate::{ - flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData + flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData, }; #[derive(Debug, Default)] @@ -33,7 +33,9 @@ impl WrappedData for Lazy<'_, Mods, T2> { } #[inline] - fn wrapped_as_base_mut(wrap: &mut Self::Wrap) -> &mut Base { + fn wrapped_as_base_mut( + wrap: &mut Self::Wrap, + ) -> &mut Base { Mods::wrapped_as_base_mut(wrap.maybe_data.data_mut().expect(MISSING_DATA)) } } diff --git a/src/wrapper.rs b/src/wrapper.rs index fe1bfed3..335733dd 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -1,13 +1,17 @@ use crate::{HasId, IsBasePtr, PtrType, Unit}; pub trait WrappedData { - type Wrap<'a, T: Unit, Base: IsBasePtr>: PtrType + HasId + 'a; + type Wrap<'a, T: Unit, Base: IsBasePtr>: PtrType + HasId + 'a; fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base>; #[track_caller] - fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b Self::Wrap<'a, T, Base>) -> &'b Base; + fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>( + wrap: &'b Self::Wrap<'a, T, Base>, + ) -> &'b Base; #[track_caller] - fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b mut Self::Wrap<'a, T, Base>) -> &'b mut Base; + fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>( + wrap: &'b mut Self::Wrap<'a, T, Base>, + ) -> &'b mut Base; } #[macro_export]