Skip to content

Commit

Permalink
Merge pull request #1 from jsatka/main
Browse files Browse the repository at this point in the history
Add Buffer type field as PhantomData.
  • Loading branch information
kenba authored Jan 15, 2021
2 parents 2598083 + 00f13fc commit 8544ec5
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 37 deletions.
45 changes: 23 additions & 22 deletions src/command_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub use cl3::command_queue::*;

use super::device::Device;
use super::event::Event;
use super::memory::Buffer;

use cl3::types::{
cl_bool, cl_command_queue, cl_command_queue_properties, cl_context, cl_device_id, cl_event,
Expand Down Expand Up @@ -136,15 +137,15 @@ impl CommandQueue {

pub fn enqueue_read_buffer<T>(
&self,
buffer: cl_mem,
buffer: &Buffer<T>,
blocking_read: cl_bool,
offset: size_t,
data: &mut [T],
event_wait_list: &[cl_event],
) -> Result<Event, cl_int> {
let event = enqueue_read_buffer(
self.queue,
buffer,
buffer.get(),
blocking_read,
offset,
(data.len() * mem::size_of::<T>()) as size_t,
Expand All @@ -159,9 +160,9 @@ impl CommandQueue {
Ok(Event::new(event))
}

pub fn enqueue_read_buffer_rect(
pub fn enqueue_read_buffer_rect<T>(
&self,
buffer: cl_mem,
buffer: &Buffer<T>,
blocking_read: cl_bool,
buffer_origin: *const size_t,
host_origin: *const size_t,
Expand All @@ -175,7 +176,7 @@ impl CommandQueue {
) -> Result<Event, cl_int> {
let event = enqueue_read_buffer_rect(
self.queue,
buffer,
buffer.get(),
blocking_read,
buffer_origin,
host_origin,
Expand All @@ -197,15 +198,15 @@ impl CommandQueue {

pub fn enqueue_write_buffer<T>(
&self,
buffer: cl_mem,
buffer: &Buffer<T>,
blocking_write: cl_bool,
offset: size_t,
data: &[T],
event_wait_list: &[cl_event],
) -> Result<Event, cl_int> {
let event = enqueue_write_buffer(
self.queue,
buffer,
buffer.get(),
blocking_write,
offset,
(data.len() * mem::size_of::<T>()) as size_t,
Expand All @@ -220,9 +221,9 @@ impl CommandQueue {
Ok(Event::new(event))
}

pub fn enqueue_write_buffer_rect(
pub fn enqueue_write_buffer_rect<T>(
&self,
buffer: cl_mem,
buffer: &Buffer<T>,
blocking_write: cl_bool,
buffer_origin: *const size_t,
host_origin: *const size_t,
Expand All @@ -236,7 +237,7 @@ impl CommandQueue {
) -> Result<Event, cl_int> {
let event = enqueue_write_buffer_rect(
self.queue,
buffer,
buffer.get(),
blocking_write,
buffer_origin,
host_origin,
Expand All @@ -258,15 +259,15 @@ impl CommandQueue {

pub fn enqueue_fill_buffer<T>(
&self,
buffer: cl_mem,
buffer: &Buffer<T>,
pattern: &[T],
offset: size_t,
size: size_t,
event_wait_list: &[cl_event],
) -> Result<Event, cl_int> {
let event = enqueue_fill_buffer(
self.queue,
buffer,
buffer.get(),
pattern.as_ptr() as cl_mem,
pattern.len() * mem::size_of::<T>(),
offset,
Expand All @@ -281,19 +282,19 @@ impl CommandQueue {
Ok(Event::new(event))
}

pub fn enqueue_copy_buffer(
pub fn enqueue_copy_buffer<T>(
&self,
src_buffer: cl_mem,
dst_buffer: cl_mem,
src_buffer: &Buffer<T>,
dst_buffer: &Buffer<T>,
src_offset: size_t,
dst_offset: size_t,
size: size_t,
event_wait_list: &[cl_event],
) -> Result<Event, cl_int> {
let event = enqueue_copy_buffer(
self.queue,
src_buffer,
dst_buffer,
src_buffer.get(),
dst_buffer.get(),
src_offset,
dst_offset,
size,
Expand All @@ -306,10 +307,10 @@ impl CommandQueue {
)?;
Ok(Event::new(event))
}
pub fn enqueue_copy_buffer_rect(
pub fn enqueue_copy_buffer_rect<T>(
&self,
src_buffer: cl_mem,
dst_buffer: cl_mem,
src_buffer: &Buffer<T>,
dst_buffer: &Buffer<T>,
src_origin: *const size_t,
dst_origin: *const size_t,
region: *const size_t,
Expand All @@ -321,8 +322,8 @@ impl CommandQueue {
) -> Result<Event, cl_int> {
let event = enqueue_copy_buffer_rect(
self.queue,
src_buffer,
dst_buffer,
src_buffer.get(),
dst_buffer.get(),
src_origin,
dst_origin,
region,
Expand Down
26 changes: 17 additions & 9 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use core::marker::PhantomData;

pub use cl3::memory::*;

use super::context::Context;
Expand Down Expand Up @@ -85,19 +87,21 @@ pub fn get_mem_properties(memobj: cl_mem) -> Result<Vec<cl_ulong>, cl_int> {

/// An OpenCL buffer.
/// Implements the Drop trait to call release_mem_object when the object is dropped.
pub struct Buffer {
pub struct Buffer<T> {
buffer: cl_mem,
#[doc(hidden)]
pub _type: PhantomData<T>,
}

impl Drop for Buffer {
impl<T> Drop for Buffer<T> {
fn drop(&mut self) {
memory::release_mem_object(self.buffer).unwrap();
}
}

impl Buffer {
pub fn new(buffer: cl_mem) -> Buffer {
Buffer { buffer }
impl<T> Buffer<T> {
pub fn new(buffer: cl_mem) -> Buffer<T> {
Buffer { buffer, _type: PhantomData }
}

/// Create a Buffer for a context.
Expand All @@ -112,12 +116,12 @@ impl Buffer {
///
/// returns a Result containing the new OpenCL buffer object
/// or the error code from the OpenCL C API function.
pub fn create<T>(
pub fn create(
context: &Context,
flags: cl_mem_flags,
count: size_t,
host_ptr: *mut c_void,
) -> Result<Buffer, cl_int> {
) -> Result<Buffer<T>, cl_int> {
let buffer =
memory::create_buffer(context.get(), flags, count * mem::size_of::<T>(), host_ptr)?;
Ok(Buffer::new(buffer))
Expand All @@ -144,7 +148,7 @@ impl Buffer {
flags: cl_mem_flags,
count: size_t,
host_ptr: *mut c_void,
) -> Result<Buffer, cl_int> {
) -> Result<Buffer<T>, cl_int> {
let buffer = memory::create_buffer_with_properties(
context.get(),
properties,
Expand All @@ -171,7 +175,7 @@ impl Buffer {
flags: cl_mem_flags,
buffer_create_type: cl_buffer_create_type,
buffer_create_info: *const c_void,
) -> Result<Buffer, cl_int> {
) -> Result<Buffer<T>, cl_int> {
let buffer =
memory::create_sub_buffer(self.buffer, flags, buffer_create_type, buffer_create_info)?;
Ok(Buffer::new(buffer))
Expand All @@ -180,6 +184,10 @@ impl Buffer {
pub fn get(&self) -> cl_mem {
self.buffer
}

pub fn cast<NewT>(&self) -> Buffer<NewT> {
Buffer::new(self.buffer)
}
}

/// An OpenCL image.
Expand Down
12 changes: 6 additions & 6 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ fn test_opencl_1_2_example() {
}

// Create OpenCL device buffers
let x = Buffer::create::<cl_float>(&context, CL_MEM_WRITE_ONLY, ARRAY_SIZE, ptr::null_mut())
let x = Buffer::<cl_float>::create(&context, CL_MEM_WRITE_ONLY, ARRAY_SIZE, ptr::null_mut())
.unwrap();
let y = Buffer::create::<cl_float>(&context, CL_MEM_WRITE_ONLY, ARRAY_SIZE, ptr::null_mut())
let y = Buffer::<cl_float>::create(&context, CL_MEM_WRITE_ONLY, ARRAY_SIZE, ptr::null_mut())
.unwrap();
let z = Buffer::create::<cl_float>(&context, CL_MEM_READ_ONLY, ARRAY_SIZE, ptr::null_mut())
let z = Buffer::<cl_float>::create(&context, CL_MEM_READ_ONLY, ARRAY_SIZE, ptr::null_mut())
.unwrap();

let queue = context.default_queue();
Expand All @@ -113,12 +113,12 @@ fn test_opencl_1_2_example() {

// Blocking write
let _x_write_event = queue
.enqueue_write_buffer(x.get(), CL_TRUE, 0, &ones, &events)
.enqueue_write_buffer(&x, CL_TRUE, 0, &ones, &events)
.unwrap();

// Non-blocking write, wait for y_write_event
let y_write_event = queue
.enqueue_write_buffer(y.get(), CL_FALSE, 0, &sums, &events)
.enqueue_write_buffer(&y, CL_FALSE, 0, &sums, &events)
.unwrap();

// Convert to CString for get_kernel function
Expand Down Expand Up @@ -147,7 +147,7 @@ fn test_opencl_1_2_example() {
// after the kernel event completes.
let mut results: [cl_float; ARRAY_SIZE] = [0.0; ARRAY_SIZE];
let _event = queue
.enqueue_read_buffer(z.get(), CL_FALSE, 0, &mut results, &events)
.enqueue_read_buffer(&z, CL_FALSE, 0, &mut results, &events)
.unwrap();
events.clear();

Expand Down

0 comments on commit 8544ec5

Please sign in to comment.