Skip to content

Commit

Permalink
Add WrappedCopy impls (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend authored Oct 13, 2024
1 parent 99d36a0 commit 46efcd5
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
# min-cl = { version = "0.3.0", optional=true }

[features]
default = ["cpu", "opencl", "blas", "static-api", "stack", "macro", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"]
default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "cuda", "vulkan", "stack"]

# default = ["cpu"]
# default = ["no-std"]
Expand Down
9 changes: 6 additions & 3 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::CPU;
use crate::{
flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, Device, DevicelessAble, HasId,
IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, ShallowCopy, Unit,
WrappedData, WriteBuf, ZeroGrad,
WrappedCopy, WrappedData, WriteBuf, ZeroGrad,
};

pub use self::num::Num;
Expand Down Expand Up @@ -477,11 +477,14 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
pub fn to_dims<O: Shape>(self) -> Buffer<'a, T, D, O>
where
D: crate::ToDim<T, S, O>,
D::Data<T, S>: ShallowCopy,
D::Data<T, S>: WrappedCopy<Base = D::Base<T, S>>,
D::Base<T, S>: ShallowCopy,
{
let base = unsafe { (*self).shallow() };
let data = self.data.wrapped_copy(base);
let buf = ManuallyDrop::new(self);

let mut data = buf.device().to_dim(unsafe { buf.data.shallow() });
let mut data = buf.device().to_dim(data);
unsafe { data.set_flag(AllocFlag::None) };

Buffer {
Expand Down
11 changes: 10 additions & 1 deletion src/devices/cpu/cpu_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use core::{

use std::alloc::handle_alloc_error;

use crate::{flag::AllocFlag, HasId, HostPtr, Id, PtrType, ShallowCopy};
use crate::{flag::AllocFlag, HasId, HostPtr, Id, PtrType, ShallowCopy, WrappedCopy};

/// The pointer used for `CPU` [`Buffer`](crate::Buffer)s
#[derive(Debug)]
Expand Down Expand Up @@ -229,6 +229,15 @@ impl<T> PtrType for CPUPtr<T> {
}
}

impl<T> WrappedCopy for CPUPtr<T> {
type Base = Self;

#[inline]
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
to_wrap
}
}

impl<T> ShallowCopy for CPUPtr<T> {
#[inline]
unsafe fn shallow(&self) -> Self {
Expand Down
11 changes: 10 additions & 1 deletion src/devices/cuda/cuda_ptr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::api::{cu_read, cufree, cumalloc, CudaResult};
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy};
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, WrappedCopy};
use core::marker::PhantomData;

/// The pointer used for `CUDA` [`Buffer`](crate::Buffer)s
Expand Down Expand Up @@ -76,6 +76,15 @@ impl<T> Drop for CUDAPtr<T> {
}
}

impl<T> WrappedCopy for CUDAPtr<T> {
type Base = Self;

#[inline]
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
to_wrap
}
}

impl<T> ShallowCopy for CUDAPtr<T> {
#[inline]
unsafe fn shallow(&self) -> Self {
Expand Down
11 changes: 10 additions & 1 deletion src/devices/opencl/cl_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::HostPtr;

use min_cl::api::release_mem_object;

use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy};
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, WrappedCopy};

/// The pointer used for `OpenCL` [`Buffer`](crate::Buffer)s
#[derive(Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -59,6 +59,15 @@ impl<T> CLPtr<T> {
}
}

impl<T> WrappedCopy for CLPtr<T> {
type Base = Self;

#[inline]
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
to_wrap
}
}

impl<T> ShallowCopy for CLPtr<T> {
#[inline]
unsafe fn shallow(&self) -> Self {
Expand Down
12 changes: 11 additions & 1 deletion src/devices/stack_array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::ops::{Deref, DerefMut};

use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy};
use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy, WrappedCopy};

/// A possibly multi-dimensional array allocated on the stack.
/// It uses `S:`[`Shape`] to get the type of the array.
Expand Down Expand Up @@ -137,6 +137,16 @@ impl<S: Shape, T> HostPtr<T> for StackArray<S, T> {
}
}


impl<S: Shape, T> WrappedCopy for StackArray<S, T> {
type Base = Self;

#[inline]
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
to_wrap
}
}

impl<S: Shape, T> ShallowCopy for StackArray<S, T>
where
S::ARR<T>: Copy,
Expand Down
11 changes: 10 additions & 1 deletion src/devices/vulkan/vk_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use core::{
};
use std::rc::Rc;

use crate::{flag::AllocFlag, HasId, HostPtr, PtrType, ShallowCopy};
use crate::{flag::AllocFlag, HasId, HostPtr, PtrType, ShallowCopy, WrappedCopy};

use super::{context::Context, submit_and_wait};

Expand Down Expand Up @@ -228,6 +228,15 @@ impl<T> VkArray<T> {
}
}

impl<T> WrappedCopy for VkArray<T> {
type Base = Self;

#[inline]
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
to_wrap
}
}

impl<T> ShallowCopy for VkArray<T> {
#[inline]
unsafe fn shallow(&self) -> Self {
Expand Down
2 changes: 1 addition & 1 deletion src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub trait Cursor {
}

#[inline]
fn cached(&self, cb: impl Fn())
fn cached(&self, cb: impl Fn())
where
Self: Sized,
{
Expand Down
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ pub trait Unit {} // useful for Sync and Send or 'static

impl<T> Unit for T {}

pub trait WrappedCopy {
type Base;
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self;
}

/// Used to shallow-copy a pointer. Use is discouraged.
pub trait ShallowCopy {
/// # Safety
Expand Down
18 changes: 17 additions & 1 deletion src/modules/autograd/wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::marker::PhantomData;

use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, WrappedData};
use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, WrappedCopy, WrappedData};

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ReqGradWrapper<Data, T> {
Expand Down Expand Up @@ -74,6 +74,22 @@ impl<Data: PtrType, T> PtrType for ReqGradWrapper<Data, T> {
}
}

impl<Data, T> WrappedCopy for ReqGradWrapper<Data, T>
where
Data: WrappedCopy<Base = T>,
{
type Base = T;

#[inline]
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
Self {
requires_grad: self.requires_grad,
data: self.data.wrapped_copy(to_wrap),
_pd: PhantomData,
}
}
}

impl<Data, T> ShallowCopy for ReqGradWrapper<Data, T>
where
Data: ShallowCopy,
Expand Down
29 changes: 25 additions & 4 deletions src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use core::{
ops::{Deref, DerefMut},
};

use crate::{flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, WrappedData};
use crate::{
flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, WrappedCopy, WrappedData,
};

#[derive(Debug, Default)]
pub struct LazyWrapper<Data, T> {
Expand Down Expand Up @@ -42,7 +44,7 @@ impl<Data: HasId, T> HasId for LazyWrapper<Data, T> {
match self.maybe_data {
MaybeData::Data(ref data) => data.id(),
MaybeData::Id(id) => id,
MaybeData::None => unimplemented!()
MaybeData::None => unimplemented!(),
}
}
}
Expand All @@ -53,13 +55,14 @@ impl<Data: PtrType, T> PtrType for LazyWrapper<Data, T> {
match self.maybe_data {
MaybeData::Data(ref data) => data.size(),
MaybeData::Id(id) => id.len,
MaybeData::None => unimplemented!()
MaybeData::None => unimplemented!(),
}
}

#[inline]
fn flag(&self) -> AllocFlag {
self.maybe_data.data()
self.maybe_data
.data()
.map(|data| data.flag())
.unwrap_or(AllocFlag::Lazy)
}
Expand Down Expand Up @@ -101,6 +104,24 @@ impl<T, Data: HostPtr<T>> HostPtr<T> for LazyWrapper<Data, T> {
}
}

impl<Data, T> WrappedCopy for LazyWrapper<Data, T>
where
Data: WrappedCopy<Base = T>,
{
type Base = T;

fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
LazyWrapper {
maybe_data: match &self.maybe_data {
MaybeData::Data(data) => MaybeData::Data(data.wrapped_copy(to_wrap)),
MaybeData::Id(id) => MaybeData::Id(*id),
MaybeData::None => unimplemented!(),
},
_pd: PhantomData,
}
}
}

impl<Data: ShallowCopy, T> ShallowCopy for LazyWrapper<Data, T> {
#[inline]
unsafe fn shallow(&self) -> Self {
Expand Down
8 changes: 4 additions & 4 deletions src/modules/lazy/wrapper/maybe_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ impl<Data> MaybeData<Data> {
MaybeData::None => None,
}
}

#[inline]
pub fn data_mut(&mut self) -> Option<&mut Data> {
match self {
MaybeData::Data(data) => Some(data),
MaybeData::Id(_id) => None,
MaybeData::None => None
MaybeData::None => None,
}
}

#[inline]
pub fn id(&self) -> Option<&Id> {
match self {
Expand All @@ -35,7 +35,7 @@ impl<Data> MaybeData<Data> {
MaybeData::None => None,
}
}

#[inline]
pub fn id_mut(&mut self) -> Option<&mut Id> {
match self {
Expand Down
6 changes: 2 additions & 4 deletions src/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl<'a, D: Cursor> CursorRangeIter<'a, D> {
pub fn previous_cursor(&self) -> &usize {
&self.previous_cursor
}

#[inline]
pub fn cursor_range(&self) -> &CursorRange<'a, D> {
&self.range
Expand Down Expand Up @@ -68,7 +68,6 @@ pub trait AsRange {
fn end(&self) -> usize;
}


// Implementing AsRange for standard Range (e.g., 0..10)
impl AsRange for Range<usize> {
#[inline]
Expand Down Expand Up @@ -173,7 +172,6 @@ impl AsRange for RangeToInclusive<usize> {
}
}


#[cfg(test)]
mod tests {
#[cfg(feature = "cpu")]
Expand Down Expand Up @@ -243,7 +241,7 @@ mod tests {
unsafe { device.bump_cursor() };
assert_eq!(device.cursor(), 10);
}
}
}

#[cfg(feature = "cpu")]
#[cfg(feature = "cached")]
Expand Down

0 comments on commit 46efcd5

Please sign in to comment.