Skip to content

Commit

Permalink
Make mutability explicit
Browse files Browse the repository at this point in the history
When a `cl_mem` is passed around, sometimes a mutable reference is needed.
Add a `get_mut()` method to the `ClMem` trait, so that the mutability
requirements automatically bubble up the API.

Closes #26.
  • Loading branch information
vmx committed Aug 11, 2021
1 parent ea2f71b commit 011ffe3
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 28 deletions.
40 changes: 20 additions & 20 deletions src/command_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,15 @@ impl CommandQueue {

pub fn enqueue_write_buffer<T>(
&self,
buffer: &Buffer<T>,
buffer: &mut Buffer<T>,
blocking_write: cl_bool,
offset: size_t,
data: &[T],
event_wait_list: &[cl_event],
) -> Result<Event> {
let event = enqueue_write_buffer(
self.queue,
buffer.get(),
buffer.get_mut(),
blocking_write,
offset,
(data.len() * mem::size_of::<T>()) as size_t,
Expand All @@ -259,7 +259,7 @@ impl CommandQueue {

pub fn enqueue_write_buffer_rect<T>(
&self,
buffer: &Buffer<T>,
buffer: &mut Buffer<T>,
blocking_write: cl_bool,
buffer_origin: *const size_t,
host_origin: *const size_t,
Expand All @@ -273,7 +273,7 @@ impl CommandQueue {
) -> Result<Event> {
let event = enqueue_write_buffer_rect(
self.queue,
buffer.get(),
buffer.get_mut(),
blocking_write,
buffer_origin,
host_origin,
Expand All @@ -295,15 +295,15 @@ impl CommandQueue {

pub fn enqueue_fill_buffer<T>(
&self,
buffer: &Buffer<T>,
buffer: &mut Buffer<T>,
pattern: &[T],
offset: size_t,
size: size_t,
event_wait_list: &[cl_event],
) -> Result<Event> {
let event = enqueue_fill_buffer(
self.queue,
buffer.get(),
buffer.get_mut(),
pattern.as_ptr() as cl_mem,
pattern.len() * mem::size_of::<T>(),
offset,
Expand All @@ -321,7 +321,7 @@ impl CommandQueue {
pub fn enqueue_copy_buffer<T>(
&self,
src_buffer: &Buffer<T>,
dst_buffer: &Buffer<T>,
dst_buffer: &mut Buffer<T>,
src_offset: size_t,
dst_offset: size_t,
size: size_t,
Expand All @@ -330,7 +330,7 @@ impl CommandQueue {
let event = enqueue_copy_buffer(
self.queue,
src_buffer.get(),
dst_buffer.get(),
dst_buffer.get_mut(),
src_offset,
dst_offset,
size,
Expand All @@ -346,7 +346,7 @@ impl CommandQueue {
pub fn enqueue_copy_buffer_rect<T>(
&self,
src_buffer: &Buffer<T>,
dst_buffer: &Buffer<T>,
dst_buffer: &mut Buffer<T>,
src_origin: *const size_t,
dst_origin: *const size_t,
region: *const size_t,
Expand All @@ -359,7 +359,7 @@ impl CommandQueue {
let event = enqueue_copy_buffer_rect(
self.queue,
src_buffer.get(),
dst_buffer.get(),
dst_buffer.get_mut(),
src_origin,
dst_origin,
region,
Expand Down Expand Up @@ -409,7 +409,7 @@ impl CommandQueue {

pub fn enqueue_write_image(
&self,
image: &Image,
image: &mut Image,
blocking_write: cl_bool,
origin: *const size_t,
region: *const size_t,
Expand All @@ -420,7 +420,7 @@ impl CommandQueue {
) -> Result<Event> {
let event = enqueue_write_image(
self.queue,
image.get(),
image.get_mut(),
blocking_write,
origin,
region,
Expand All @@ -439,15 +439,15 @@ impl CommandQueue {

pub fn enqueue_fill_image(
&self,
image: &Image,
image: &mut Image,
fill_color: *const c_void,
origin: *const size_t,
region: *const size_t,
event_wait_list: &[cl_event],
) -> Result<Event> {
let event = enqueue_fill_image(
self.queue,
image.get(),
image.get_mut(),
fill_color,
origin,
region,
Expand All @@ -464,7 +464,7 @@ impl CommandQueue {
pub fn enqueue_copy_image(
&self,
src_image: &Image,
dst_image: &Image,
dst_image: &mut Image,
src_origin: *const size_t,
dst_origin: *const size_t,
region: *const size_t,
Expand All @@ -473,7 +473,7 @@ impl CommandQueue {
let event = enqueue_copy_image(
self.queue,
src_image.get(),
dst_image.get(),
dst_image.get_mut(),
src_origin,
dst_origin,
region,
Expand All @@ -490,7 +490,7 @@ impl CommandQueue {
pub fn enqueue_copy_image_to_buffer<T>(
&self,
src_image: &Image,
dst_buffer: &Buffer<T>,
dst_buffer: &mut Buffer<T>,
src_origin: *const size_t,
region: *const size_t,
dst_offset: size_t,
Expand All @@ -499,7 +499,7 @@ impl CommandQueue {
let event = enqueue_copy_image_to_buffer(
self.queue,
src_image.get(),
dst_buffer.get(),
dst_buffer.get_mut(),
src_origin,
region,
dst_offset,
Expand All @@ -516,7 +516,7 @@ impl CommandQueue {
pub fn enqueue_copy_buffer_to_image<T>(
&self,
src_buffer: &Buffer<T>,
dst_image: &Image,
dst_image: &mut Image,
src_offset: size_t,
dst_origin: *const size_t,
region: *const size_t,
Expand All @@ -525,7 +525,7 @@ impl CommandQueue {
let event = enqueue_copy_buffer_to_image(
self.queue,
src_buffer.get(),
dst_image.get(),
dst_image.get_mut(),
src_offset,
dst_origin,
region,
Expand Down
4 changes: 2 additions & 2 deletions src/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,15 @@ mod tests {
const ARRAY_SIZE: usize = 1024;
let ones: [cl_float; ARRAY_SIZE] = [1.0; ARRAY_SIZE];

let buffer =
let mut buffer =
Buffer::<cl_float>::create(&context, CL_MEM_WRITE_ONLY, ARRAY_SIZE, ptr::null_mut())
.unwrap();

let events: Vec<cl_event> = Vec::default();

// Non-blocking write, wait for event
let event = queue
.enqueue_write_buffer(&buffer, CL_NON_BLOCKING, 0, &ones, &events)
.enqueue_write_buffer(&mut buffer, CL_NON_BLOCKING, 0, &ones, &events)
.unwrap();

// Set a callback_function on the event (i.e. write) being completed.
Expand Down
14 changes: 14 additions & 0 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ use std::mem;
pub trait ClMem {
fn get(&self) -> cl_mem;

fn get_mut(&mut self) -> cl_mem;

fn mem_type(&self) -> Result<cl_mem_object_type> {
Ok(memory::get_mem_object_info(self.get(), MemInfo::CL_MEM_TYPE)?.to_uint())
}
Expand Down Expand Up @@ -128,6 +130,10 @@ impl<T> ClMem for Buffer<T> {
fn get(&self) -> cl_mem {
self.buffer
}

fn get_mut(&mut self) -> cl_mem {
self.buffer
}
}

impl<T> Drop for Buffer<T> {
Expand Down Expand Up @@ -282,6 +288,10 @@ impl ClMem for Image {
fn get(&self) -> cl_mem {
self.image
}

fn get_mut(&mut self) -> cl_mem {
self.image
}
}

impl Drop for Image {
Expand Down Expand Up @@ -766,6 +776,10 @@ impl ClMem for Pipe {
fn get(&self) -> cl_mem {
self.pipe
}

fn get_mut(&mut self) -> cl_mem {
self.pipe
}
}

impl Drop for Pipe {
Expand Down
14 changes: 8 additions & 6 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,25 @@ fn test_opencl_1_2_example() {
}

// Create OpenCL device buffers
let x = Buffer::<cl_float>::create(&context, CL_MEM_WRITE_ONLY, ARRAY_SIZE, ptr::null_mut())
.unwrap();
let y = Buffer::<cl_float>::create(&context, CL_MEM_WRITE_ONLY, ARRAY_SIZE, ptr::null_mut())
.unwrap();
let mut x =
Buffer::<cl_float>::create(&context, CL_MEM_WRITE_ONLY, ARRAY_SIZE, ptr::null_mut())
.unwrap();
let mut y =
Buffer::<cl_float>::create(&context, CL_MEM_WRITE_ONLY, ARRAY_SIZE, ptr::null_mut())
.unwrap();
let z = Buffer::<cl_float>::create(&context, CL_MEM_READ_ONLY, ARRAY_SIZE, ptr::null_mut())
.unwrap();

let mut events: Vec<cl_event> = Vec::default();

// Blocking write
let _x_write_event = queue
.enqueue_write_buffer(&x, CL_BLOCKING, 0, &ones, &events)
.enqueue_write_buffer(&mut x, CL_BLOCKING, 0, &ones, &events)
.unwrap();

// Non-blocking write, wait for y_write_event
let y_write_event = queue
.enqueue_write_buffer(&y, CL_NON_BLOCKING, 0, &sums, &events)
.enqueue_write_buffer(&mut y, CL_NON_BLOCKING, 0, &sums, &events)
.unwrap();

// a value for the kernel function
Expand Down

0 comments on commit 011ffe3

Please sign in to comment.