feature: adding blocking in a table convert function (#2625)
Alexandr-Solovev authored Jan 10, 2024
1 parent e1daf23 commit 138ddd6
Showing 3 changed files with 103 additions and 12 deletions.
7 changes: 7 additions & 0 deletions cpp/oneapi/dal/backend/transfer.hpp
@@ -93,6 +93,13 @@ sycl::event scatter_host2device(sycl::queue& q,
                                 std::int64_t dst_stride_in_bytes,
                                 std::int64_t block_size_in_bytes,
                                 const event_vector& deps = {});
+sycl::event scatter_host2device_blocking(sycl::queue& q,
+                                         void* dst_device,
+                                         const void* src_host,
+                                         std::int64_t block_count,
+                                         std::int64_t dst_stride_in_bytes,
+                                         std::int64_t block_size_in_bytes,
+                                         const event_vector& deps = {});
 #endif
 
 } // namespace oneapi::dal::backend
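The new overload mirrors `scatter_host2device` but issues several bounded kernel launches instead of one. A minimal usage sketch follows; the queue setup, sizes, and the `example` wrapper are illustrative assumptions, not part of this commit:

    #include <sycl/sycl.hpp>
    #include "oneapi/dal/backend/transfer.hpp"

    namespace bk = oneapi::dal::backend;

    // Hypothetical caller: scatter `n` 4-byte elements from host memory
    // into a device buffer that has a 16-byte row stride.
    sycl::event example(sycl::queue& q, void* dst_device, const void* src_host, std::int64_t n) {
        return bk::scatter_host2device_blocking(q,
                                                dst_device,
                                                src_host,
                                                /*block_count=*/n,
                                                /*dst_stride_in_bytes=*/16,
                                                /*block_size_in_bytes=*/4);
    }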
80 changes: 76 additions & 4 deletions cpp/oneapi/dal/backend/transfer_dpc.cpp
@@ -18,6 +18,12 @@
 #include <algorithm>
 
 namespace oneapi::dal::backend {
+namespace bk = dal::backend;
+template <typename Float>
+std::int64_t propose_block_size(const sycl::queue& q, const std::int64_t r) {
+    constexpr std::int64_t fsize = sizeof(Float);
+    return 0x10000l * (8 / fsize);
+}
 
 sycl::event gather_device2host(sycl::queue& q,
                                void* dst_host,
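A note on the heuristic above: `propose_block_size` currently ignores its `q` and `r` arguments and returns a fixed 0x10000 (65536) rows scaled by 8 / sizeof(Float), i.e. 131072 rows per launch for float and 65536 for double. A standalone check of that arithmetic (not part of the commit):

    #include <cstdint>

    // Mirrors the heuristic: 0x10000l * (8 / sizeof(Float)).
    template <typename Float>
    constexpr std::int64_t proposed_block_size() {
        return 0x10000l * (8 / static_cast<std::int64_t>(sizeof(Float)));
    }

    static_assert(proposed_block_size<float>() == 131072, "65536 * (8 / 4)");
    static_assert(proposed_block_size<double>() == 65536, "65536 * (8 / 8)");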
@@ -101,18 +107,19 @@ sycl::event scatter_host2device(sycl::queue& q,
     auto scatter_event = q.submit([&](sycl::handler& cgh) {
         cgh.depends_on(copy_event);
 
-        byte_t* gathered_byte = reinterpret_cast<byte_t*>(gathered_device_unique.get());
-        byte_t* dst_byte = reinterpret_cast<byte_t*>(dst_device);
+        const byte_t* const gathered_byte =
+            reinterpret_cast<const byte_t*>(gathered_device_unique.get());
+        byte_t* const dst_byte = reinterpret_cast<byte_t*>(dst_device);
 
-        const std::int64_t required_local_size = 256;
+        const std::int64_t required_local_size = bk::device_max_wg_size(q);
         const std::int64_t local_size = std::min(down_pow2(block_count), required_local_size);
         const auto range = make_multiple_nd_range_1d(block_count, local_size);
 
         cgh.parallel_for(range, [=](sycl::nd_item<1> id) {
             const auto i = id.get_global_id();
             if (i < block_count) {
                 // TODO: Unroll for optimization
-                for (int j = 0; j < block_size_in_bytes; j++) {
+                for (std::int64_t j = 0; j < block_size_in_bytes; ++j) {
                     dst_byte[i * dst_stride_in_bytes + j] =
                         gathered_byte[i * block_size_in_bytes + j];
                 }
@@ -127,4 +134,69 @@ sycl::event scatter_host2device(sycl::queue& q,
     return sycl::event{};
 }
 
+sycl::event scatter_host2device_blocking(sycl::queue& q,
+                                         void* dst_device,
+                                         const void* src_host,
+                                         std::int64_t block_count,
+                                         std::int64_t dst_stride_in_bytes,
+                                         std::int64_t block_size_in_bytes,
+                                         const event_vector& deps) {
+    ONEDAL_ASSERT(dst_device);
+    ONEDAL_ASSERT(src_host);
+    ONEDAL_ASSERT(block_count > 0);
+    ONEDAL_ASSERT(dst_stride_in_bytes > 0);
+    ONEDAL_ASSERT(block_size_in_bytes > 0);
+    ONEDAL_ASSERT(dst_stride_in_bytes >= block_size_in_bytes);
+    ONEDAL_ASSERT(is_known_usm(q, dst_device));
+    ONEDAL_ASSERT_MUL_OVERFLOW(std::int64_t, block_count, block_size_in_bytes);
+    const auto gathered_device_unique =
+        make_unique_usm_device(q, block_count * block_size_in_bytes);
+
+    auto copy_event = memcpy_host2usm(q,
+                                      gathered_device_unique.get(),
+                                      src_host,
+                                      block_count * block_size_in_bytes,
+                                      deps);
+
+    const byte_t* const gathered_byte =
+        reinterpret_cast<const byte_t*>(gathered_device_unique.get());
+    byte_t* const dst_byte = reinterpret_cast<byte_t*>(dst_device);
+
+    const auto block_size = propose_block_size<float>(q, block_count);
+    const bk::uniform_blocking blocking(block_count, block_size);
+    std::vector<sycl::event> events;
+
+    const auto block_range = blocking.get_block_count();
+
+    for (std::int64_t block_index = 0; block_index < block_range; ++block_index) {
+        const auto start_block = blocking.get_block_start_index(block_index);
+        const auto end_block = blocking.get_block_end_index(block_index);
+        const auto curr_block = end_block - start_block;
+        ONEDAL_ASSERT(curr_block > 0);
+
+        auto scatter_event = q.submit([&](sycl::handler& cgh) {
+            cgh.depends_on(copy_event);
+
+            const std::int64_t required_local_size = bk::device_max_wg_size(q);
+            const std::int64_t local_size = std::min(down_pow2(curr_block), required_local_size);
+            const auto range = make_multiple_nd_range_1d(curr_block, local_size);
+
+            cgh.parallel_for(range, [=](sycl::nd_item<1> id) {
+                const auto i = id.get_global_id() + start_block;
+                if (i < block_count) {
+                    // TODO: Unroll for optimization
+                    for (std::int64_t j = 0; j < block_size_in_bytes; ++j) {
+                        dst_byte[i * dst_stride_in_bytes + j] =
+                            gathered_byte[i * block_size_in_bytes + j];
+                    }
+                }
+            });
+        });
+        events.push_back(scatter_event);
+    }
+    // We need to wait until all scatter kernels are completed to deallocate
+    // `gathered_device_unique`.
+    return bk::wait_or_pass(events);
+}
+
 } // namespace oneapi::dal::backend
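The per-block loop above leans on `bk::uniform_blocking` to partition `[0, block_count)` into chunks of at most `block_size` indices. A minimal sketch of that partitioning contract, assuming it is a plain ceiling-division splitter (the real helper is defined elsewhere in the backend):

    #include <cstdint>

    // Hypothetical stand-in for bk::uniform_blocking.
    struct uniform_blocking_sketch {
        std::int64_t count;  // total number of indices
        std::int64_t block;  // maximum indices per chunk

        std::int64_t get_block_count() const {
            return (count + block - 1) / block;  // ceiling division
        }
        std::int64_t get_block_start_index(std::int64_t i) const {
            return i * block;
        }
        std::int64_t get_block_end_index(std::int64_t i) const {
            const std::int64_t end = (i + 1) * block;
            return end < count ? end : count;  // clamp the final chunk
        }
    };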
28 changes: 20 additions & 8 deletions cpp/oneapi/dal/table/backend/convert.cpp
@@ -291,14 +291,26 @@ sycl::event convert_vector_host2device(sycl::queue& q,
                                        src_stride,
                                        1L,
                                        element_count);
 
-    auto scatter_event = scatter_host2device(q,
-                                             dst_device,
-                                             tmp_host_unique.get(),
-                                             element_count,
-                                             dst_stride_in_bytes,
-                                             element_size_in_bytes,
-                                             deps);
-
+    const std::int64_t max_loop_range = std::numeric_limits<std::int32_t>::max();
+    sycl::event scatter_event;
+    if (element_count > max_loop_range) {
+        scatter_event = scatter_host2device_blocking(q,
+                                                     dst_device,
+                                                     tmp_host_unique.get(),
+                                                     element_count,
+                                                     dst_stride_in_bytes,
+                                                     element_size_in_bytes,
+                                                     deps);
+    }
+    else {
+        scatter_event = scatter_host2device(q,
+                                            dst_device,
+                                            tmp_host_unique.get(),
+                                            element_count,
+                                            dst_stride_in_bytes,
+                                            element_size_in_bytes,
+                                            deps);
+    }
     return scatter_event;
 }
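The dispatch above takes the blocking path only when `element_count` exceeds `INT32_MAX`, i.e. when a flat per-element kernel range would no longer fit a 32-bit loop index; smaller conversions keep the original single-launch `scatter_host2device`. A back-of-the-envelope check of the resulting launch count (illustrative, not from the commit):

    #include <cstdint>
    #include <limits>

    static_assert(std::numeric_limits<std::int32_t>::max() == 2147483647,
                  "threshold for switching to the blocking path");
    // With the float heuristic of 131072 rows per launch, a conversion of
    // 2^31 elements is split into 2^31 / 2^17 = 16384 kernel launches.
    static_assert((std::int64_t{1} << 31) / 131072 == 16384, "launch count");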

