Skip to content

Commit

Permalink
Update LockedMap
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 1, 2024
1 parent ccdef89 commit 9ebd957
Show file tree
Hide file tree
Showing 28 changed files with 215 additions and 104 deletions.
4 changes: 3 additions & 1 deletion examples/custom_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ impl<Mods: WrappedData> WrappedData for CustomModule<Mods> {
}

#[inline]
fn wrapped_as_base_mut<T: Unit, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
fn wrapped_as_base_mut<T: Unit, Base: HasId + PtrType>(
wrap: &mut Self::Wrap<T, Base>,
) -> &mut Base {
Mods::wrapped_as_base_mut(wrap)
}
}
Expand Down
16 changes: 12 additions & 4 deletions src/buffer/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use core::{
};

use crate::{
flag::AllocFlag, Alloc, Buffer, CloneBuf, Device, HasId, IsBasePtr, OnDropBuffer, PtrType, ShallowCopy, Unit, WrappedData
flag::AllocFlag, Alloc, Buffer, CloneBuf, Device, HasId, IsBasePtr, OnDropBuffer, PtrType,
ShallowCopy, Unit, WrappedData,
};

#[derive(Debug, Default)]
Expand Down Expand Up @@ -60,7 +61,10 @@ impl Device for () {
}

#[inline(always)]
fn base_to_data<'a, T: Unit, S: crate::Shape>(&self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
fn base_to_data<'a, T: Unit, S: crate::Shape>(
&self,
base: Self::Base<T, S>,
) -> Self::Data<'a, T, S> {
base
}

Expand Down Expand Up @@ -114,12 +118,16 @@ impl WrappedData for () {
}

#[inline]
fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b Self::Wrap<'a, T, Base>) -> &'b Base {
fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(
wrap: &'b Self::Wrap<'a, T, Base>,
) -> &'b Base {
wrap
}

#[inline]
fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b mut Self::Wrap<'a, T, Base>) -> &'b mut Base {
fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: IsBasePtr>(
wrap: &'b mut Self::Wrap<'a, T, Base>,
) -> &'b mut Base {
wrap
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/cache/locking.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod guard;
mod locked_array;
mod locked_map;
mod guard;
pub use guard::*;
pub use locked_array::*;
pub use locked_map::*;
Expand All @@ -11,4 +11,4 @@ pub type State<T> = Result<T, LockInfo>;
pub enum LockInfo {
Locked,
None,
}
}
10 changes: 5 additions & 5 deletions src/cache/locking/guard.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use core::{mem::ManuallyDrop, ops::{Deref, DerefMut}};
use core::{
mem::ManuallyDrop,
ops::{Deref, DerefMut},
};

use crate::{CowMutCell, HasId, PtrType, ShallowCopy};

Expand All @@ -18,9 +21,7 @@ impl<'a, T> Guard<'a, T> {
F: FnOnce(CowMutCell<'a, T>) -> CowMutCell<'a, U>,
{
let Guard { data } = self;
Guard {
data: f(data),
}
Guard { data: f(data) }
}
}

Expand Down Expand Up @@ -69,4 +70,3 @@ impl<'a, T> ShallowCopy for Guard<'a, T> {
todo!()
}
}

11 changes: 5 additions & 6 deletions src/cache/locking/locked_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ impl<T, const N: usize> LockedArray<T, N> {
if data.is_none() {
return State::Err(LockInfo::None);
}
return State::Ok(Guard::new(CowMutCell::Borrowed(Ref::map(
data,
|data| data.as_ref().unwrap(),
))));
return State::Ok(Guard::new(CowMutCell::Borrowed(Ref::map(data, |data| {
data.as_ref().unwrap()
}))));
}
Err(_) => return State::Err(LockInfo::Locked),
}
Expand Down Expand Up @@ -84,7 +83,7 @@ mod tests {
data1.push(2);
assert_eq!(data1.as_slice(), [1, 2]);
}

#[cfg(feature = "std")]
#[test]
#[should_panic]
Expand All @@ -93,7 +92,7 @@ mod tests {
locked_array.set(1, vec![10]);
locked_array.set(1, vec![10]);
}

#[cfg(feature = "std")]
#[test]
fn test_get_not_set() {
Expand Down
21 changes: 16 additions & 5 deletions src/cache/locking/locked_map.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::ops::Deref;
use std::{
cell::{Ref, RefCell, RefMut, UnsafeCell},
collections::HashMap,
Expand All @@ -7,7 +8,7 @@ use std::{
use crate::{LockInfo, State};

pub struct LockedMap<K, V, S = RandomState> {
data: UnsafeCell<HashMap<K, Box<RefCell<V>>, S>>,
data: RefCell<HashMap<K, Box<RefCell<V>>, S>>,
}

impl<K, T, S: Default> Default for LockedMap<K, T, S> {
Expand All @@ -27,11 +28,22 @@ impl<K, T, S: Default> LockedMap<K, T, S> {
}
}
impl<K, T, S: BuildHasher> LockedMap<K, T, S> {
#[inline]
pub fn len(&self) -> usize {
self.data.borrow().len()
}

#[inline]
pub fn is_empty(&self) -> bool {
self.data.borrow().is_empty()
}

pub fn insert(&self, id: K, data: T)
where
K: Eq + Hash,
{
let map = unsafe { &mut *self.data.get() };
// let map = unsafe { &mut *self.data.get() };
let mut map = self.data.borrow_mut();
if map.contains_key(&id) {
panic!()
}
Expand All @@ -42,7 +54,7 @@ impl<K, T, S: BuildHasher> LockedMap<K, T, S> {
where
K: Eq + Hash,
{
let map = unsafe { &mut *self.data.get() };
let map = unsafe { &*self.data.as_ptr() };
let entry = map.get(id).ok_or(LockInfo::None)?;
(&**entry).try_borrow().map_err(|_| LockInfo::Locked)
}
Expand All @@ -51,7 +63,7 @@ impl<K, T, S: BuildHasher> LockedMap<K, T, S> {
where
K: Eq + Hash,
{
let map = unsafe { &mut *self.data.get() };
let map = unsafe { &*self.data.as_ptr() };
let entry = map.get(id).ok_or(LockInfo::None)?;
(&**entry).try_borrow_mut().map_err(|_| LockInfo::Locked)
}
Expand All @@ -67,7 +79,6 @@ mod tests {

#[test]
fn test_locked_boxed() {

let locked_map = LockedMap::<UniqueId, Vec<u32>, BuildHasherDefault<NoHasher>>::new();

locked_map.insert(0, vec![1, 2, 3, 4]);
Expand Down
4 changes: 2 additions & 2 deletions src/cache/owned_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub trait Cache {
) -> D::Base<T, S>
where
T: Unit,
D: Alloc<T> + 'static,
D::Base<T, S>: ShallowCopy + 'static,
D: Alloc<T>,
D::Base<T, S>: ShallowCopy,
S: Shape;
}
40 changes: 23 additions & 17 deletions src/cache/owned_cache/fast_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ use core::{any::Any, hash::BuildHasherDefault};
use std::{collections::HashMap, sync::Arc};

use crate::{
flag::AllocFlag, Alloc, Cache, Device, NoHasher, PtrType, ShallowCopy, Shape, UniqueId, Unit,
flag::AllocFlag, Alloc, Cache, Device, LockedMap, NoHasher, PtrType, ShallowCopy, Shape, UniqueId, Unit
};

#[derive(Clone)]
pub struct FastCache {
pub nodes: HashMap<UniqueId, Arc<dyn Any>, BuildHasherDefault<NoHasher>>,
pub nodes: LockedMap<UniqueId, Arc<dyn Any>, BuildHasherDefault<NoHasher>>,
}

impl Default for FastCache {
Expand All @@ -28,11 +27,11 @@ impl Cache for FastCache {
) -> D::Base<T, S>
where
T: Unit,
D: Alloc<T> + 'static,
D: Alloc<T>,
D::Base<T, S>: ShallowCopy + 'static,
S: Shape,
{
self.get(device, id, len, new_buf_callback)
self.get(device, id, len, new_buf_callback).unwrap()
}
}

Expand All @@ -53,16 +52,16 @@ impl FastCache {
id: UniqueId,
len: usize,
new_buf_callback: impl FnMut(UniqueId, &D::Base<T, S>),
) -> D::Base<T, S>
) -> crate::Result<D::Base<T, S>>
where
T: Unit,
D: Alloc<T> + 'static,
D: Alloc<T>,
D::Base<T, S>: ShallowCopy + 'static,
S: Shape,
{
let maybe_allocated = self.nodes.get(&id);
match maybe_allocated {
Some(data) => {
Ok(data) => {
let data = unsafe {
data.downcast_ref::<D::Base<T, S>>()
.expect("Invalid request for data type!")
Expand All @@ -71,32 +70,39 @@ impl FastCache {

// TODO: not necessary, could add length to hashmap
assert_eq!(data.size(), len, "Data size mismatch! Did you use e.g. if conditions in a (cursor) loop retrieving buffers with a different size?");
data
Ok(data)
}
None => unsafe { self.add_node(device, id, len, new_buf_callback) },
Err(e) => {
match e {
crate::LockInfo::Locked => panic!("should return error"),
crate::LockInfo::None => {
unsafe { self.add_node(device, id, len, new_buf_callback) }
},
}
},
}
}

unsafe fn add_node<T, S, D>(
&mut self,
&self,
device: &D,
id: UniqueId,
len: usize,
mut callback: impl FnMut(UniqueId, &D::Base<T, S>),
) -> <D as Device>::Base<T, S>
) -> crate::Result<<D as Device>::Base<T, S>>
where
T: Unit,
D: Alloc<T>,
D::Base<T, S>: ShallowCopy + 'static,
S: Shape,
{
let data = device.alloc::<S>(len, AllocFlag::None).unwrap();
let data = device.alloc::<S>(len, AllocFlag::None)?;
let shallow_data = unsafe { data.shallow() };

callback(id, &shallow_data);
self.nodes.insert(id, Arc::new(data));

shallow_data
Ok(shallow_data)
}
}

Expand All @@ -115,12 +121,12 @@ mod tests {

assert_eq!(cache.nodes.len(), 0);

let out = unsafe { cache.add_node::<f32, (), _>(&device, 0, 10, |_a, _b| ()) };
let out = unsafe { cache.add_node::<f32, (), _>(&device, 0, 10, |_a, _b| ()) }.unwrap();

assert_eq!(cache.nodes.len(), 1);
assert_eq!(out.len, 10);

let out1 = unsafe { cache.get::<f32, (), _>(&device, 1, 10, |_a, _b| ()) };
let out1 = unsafe { cache.get::<f32, (), _>(&device, 1, 10, |_a, _b| ()) }.unwrap();
assert_ne!(out.ptr, out1.ptr);
}

Expand All @@ -135,7 +141,7 @@ mod tests {

let mut prev = None;
for _ in device.range(0..1000) {
let out3 = unsafe { cache.get::<f32, (), _>(&device, 0, 10, |_a, _b| ()) };
let out3 = unsafe { cache.get::<f32, (), _>(&device, 0, 10, |_a, _b| ()) }.unwrap();
if prev.is_none() {
prev = Some(out3.ptr);
}
Expand Down
4 changes: 2 additions & 2 deletions src/cache/owned_cache/length_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl Cache for LengthCache {
) -> D::Base<T, S>
where
T: Unit,
D: Alloc<T> + 'static,
D: Alloc<T>,
D::Base<T, S>: ShallowCopy + 'static,
S: Shape,
{
Expand Down Expand Up @@ -54,7 +54,7 @@ impl LengthCache {
) -> D::Base<T, S>
where
T: Unit,
D: Alloc<T> + 'static,
D: Alloc<T>,
D::Base<T, S>: ShallowCopy + 'static,
S: Shape,
{
Expand Down
2 changes: 1 addition & 1 deletion src/cow_mut.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::ops::{Deref, DerefMut};
use core::cell::{Ref, RefMut};
use core::ops::{Deref, DerefMut};

pub type CowMutCell<'a, T> = CowMut<T, RefMut<'a, T>, Ref<'a, T>>;
pub type CowMutRef<'a, T> = CowMut<T, &'a T, &'a mut T>;
Expand Down
5 changes: 3 additions & 2 deletions src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ pub trait Device: OnDropBuffer + Sized {
&self,
wrap: Self::Wrap<'a, T, Self::Base<T, S>>,
) -> Self::Data<'a, T, S>;
fn data_as_wrap<'a, 'b, T: Unit, S: Shape>(data: &'b Self::Data<'a, T, S>)
-> &'b Self::Wrap<'a, T, Self::Base<T, S>>;
fn data_as_wrap<'a, 'b, T: Unit, S: Shape>(
data: &'b Self::Data<'a, T, S>,
) -> &'b Self::Wrap<'a, T, Self::Base<T, S>>;
fn data_as_wrap_mut<'a, 'b, T: Unit, S: Shape>(
data: &'b mut Self::Data<'a, T, S>,
) -> &'b mut Self::Wrap<'a, T, Self::Base<T, S>>;
Expand Down
7 changes: 6 additions & 1 deletion src/devices/cpu/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,12 @@ where
D::Base<T, S>: Deref<Target = [T]>,
S: Shape,
{
type Read<'a> = &'a [T] where T: 'a, D: 'a, S: 'a;
type Read<'a>
= &'a [T]
where
T: 'a,
D: 'a,
S: 'a;

#[inline]
fn read<'a>(&self, buf: &'a D::Base<T, S>) -> Self::Read<'a>
Expand Down
3 changes: 2 additions & 1 deletion src/devices/cuda/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ pass_down_add_operation!(CUDA);
pass_down_exec_now!(CUDA);

impl<Mods: OnDropBuffer, T: Unit + Default + Clone, S: Shape> Read<T, S> for CUDA<Mods> {
type Read<'a> = Vec<T>
type Read<'a>
= Vec<T>
where
T: 'a,
CUDA<Mods>: 'a;
Expand Down
10 changes: 8 additions & 2 deletions src/devices/opencl/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,15 @@ where
S: Shape,
{
#[cfg(not(unified_cl))]
type Read<'a> = Vec<T> where T: 'a;
type Read<'a>
= Vec<T>
where
T: 'a;
#[cfg(unified_cl)]
type Read<'a> = &'a [T] where T: 'a;
type Read<'a>
= &'a [T]
where
T: 'a;

#[cfg(not(unified_cl))]
#[inline]
Expand Down
Loading

0 comments on commit 9ebd957

Please sign in to comment.