Skip to content

Commit

Permalink
test init behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
paleolimbot committed Jun 26, 2024
1 parent 05d053f commit 20fc135
Showing 1 changed file with 80 additions and 1 deletion.
81 changes: 80 additions & 1 deletion src/nanoarrow/device/cuda_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,37 @@ class CudaStream {
CUstream hstream_;
};

class CudaEvent {
public:
CudaEvent(int64_t device_id) : device_id_(device_id), hevent_(nullptr) {}

ArrowErrorCode Init() {
CudaTemporaryContext ctx(device_id_);
if (!ctx.valid()) {
return EINVAL;
}

if (cuEventCreate(&hevent_, CU_EVENT_DEFAULT) != CUDA_SUCCESS) {
return EINVAL;
}

return NANOARROW_OK;
}

CUevent* get() { return &hevent_; }

void release() { hevent_ = nullptr; }

~CudaEvent() {
if (hevent_ != nullptr) {
cuEventDestroy(hevent_);
}
}

int64_t device_id_;
CUevent hevent_;
};

TEST(NanoarrowDeviceCuda, GetDevice) {
struct ArrowDevice* cuda = ArrowDeviceCuda(ARROW_DEVICE_CUDA, 0);
ASSERT_NE(cuda, nullptr);
Expand Down Expand Up @@ -241,7 +272,55 @@ TEST(NanoarrowDeviceCuda, DeviceCudaBufferCopy) {
}
}

TEST(NanoarrowDeviceCuda, DeviceCudaArrayInit) {}
TEST(NanoarrowDeviceCuda, DeviceCudaArrayInit) {
struct ArrowDevice* gpu = ArrowDeviceCuda(ARROW_DEVICE_CUDA, 0);

CudaStream stream(gpu->device_id);
ASSERT_EQ(stream.Init(), NANOARROW_OK);

CudaEvent event(gpu->device_id);
ASSERT_EQ(event.Init(), NANOARROW_OK);

struct ArrowDeviceArray device_array;
struct ArrowArray array;
array.release = nullptr;

// No provided sync event should result in a null sync event in the final array
ASSERT_EQ(ArrowArrayInitFromType(&array, NANOARROW_TYPE_INT32), NANOARROW_OK);
ASSERT_EQ(ArrowDeviceArrayInit(gpu, &device_array, &array, nullptr), NANOARROW_OK);
ASSERT_EQ(device_array.sync_event, nullptr);
ArrowArrayRelease(&device_array.array);

// Provided sync event should result in ownership of the event being taken by the
// device array.
device_array.sync_event = nullptr;
ASSERT_EQ(ArrowArrayInitFromType(&array, NANOARROW_TYPE_INT32), NANOARROW_OK);
ASSERT_EQ(ArrowDeviceArrayInit(gpu, &device_array, &array, event.get()), NANOARROW_OK);
ASSERT_EQ(*((CUevent*)device_array.sync_event), *event.get());
event.release();
ArrowArrayRelease(&device_array.array);

// Provided stream without provided event should result in an event created by and owned
// by the device array
device_array.sync_event = nullptr;
ASSERT_EQ(ArrowArrayInitFromType(&array, NANOARROW_TYPE_INT32), NANOARROW_OK);
ASSERT_EQ(ArrowDeviceArrayInitAsync(gpu, &device_array, &array, nullptr, stream.get()),
NANOARROW_OK);
ASSERT_NE(*(CUevent*)device_array.sync_event, nullptr);
ArrowArrayRelease(&device_array.array);

// Provided stream and sync event should result in the device array taking ownership
// and recording the event
ASSERT_EQ(event.Init(), NANOARROW_OK);
device_array.sync_event = nullptr;
ASSERT_EQ(ArrowArrayInitFromType(&array, NANOARROW_TYPE_INT32), NANOARROW_OK);
ASSERT_EQ(
ArrowDeviceArrayInitAsync(gpu, &device_array, &array, event.get(), stream.get()),
NANOARROW_OK);
ASSERT_EQ(*((CUevent*)device_array.sync_event), *event.get());
event.release();
ArrowArrayRelease(&device_array.array);
}

class StringTypeParameterizedTestFixture
: public ::testing::TestWithParam<std::tuple<ArrowDeviceType, enum ArrowType, bool>> {
Expand Down

0 comments on commit 20fc135

Please sign in to comment.