Skip to content

Commit

Permalink
ComputeCpp local memory support through static SYCL handler hack
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Jan 18, 2022
1 parent b867598 commit 8e2fce4
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 39 deletions.
6 changes: 1 addition & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,7 @@ else()
set(CELERITY_FEATURE_SIMPLE_SCALAR_REDUCTIONS OFF)
endif()

if(CELERITY_SYCL_IMPL STREQUAL hipSYCL OR CELERITY_SYCL_IMPL STREQUAL "DPC++")
set(CELERITY_FEATURE_LOCAL_ACCESSOR ON)
else()
set(CELERITY_FEATURE_LOCAL_ACCESSOR OFF)
endif()
set(CELERITY_FEATURE_LOCAL_ACCESSOR ON)

if(NOT CELERITY_SYCL_IMPL STREQUAL ComputeCpp)
set(CELERITY_FEATURE_UNNAMED_KERNELS ON)
Expand Down
16 changes: 0 additions & 16 deletions examples/matmul/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ void multiply(celerity::distr_queue queue, celerity::buffer<T, 2> mat_a, celerit
celerity::accessor b{mat_b, cgh, celerity::access::slice<2>(0), celerity::read_only};
celerity::accessor c{mat_c, cgh, celerity::access::one_to_one{}, celerity::write_only, celerity::no_init};

#if CELERITY_FEATURE_LOCAL_ACCESSOR

// Use local-memory tiling to avoid waiting on global memory too often
const size_t GROUP_SIZE = 8;
celerity::local_accessor<T, 2> scratch_a{{GROUP_SIZE, GROUP_SIZE}, cgh};
Expand All @@ -43,20 +41,6 @@ void multiply(celerity::distr_queue queue, celerity::buffer<T, 2> mat_a, celerit
}
c[item.get_global_id()] = sum;
});

#else

cgh.parallel_for<class mat_mul>(celerity::range<2>(MAT_SIZE, MAT_SIZE), [=](celerity::item<2> item) {
T sum{};
for(size_t k = 0; k < MAT_SIZE; ++k) {
const auto a_ik = a[{item[0], k}];
const auto b_kj = b[{k, item[1]}];
sum += a_ik * b_kj;
}
c[item] = sum;
});

#endif
});
}

Expand Down
42 changes: 26 additions & 16 deletions include/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ namespace detail {
return *hack_make_invisible_nullptr<T>();
}

#if WORKAROUND_COMPUTECPP
class hack_null_sycl_handler: public sycl::handler {
public:
hack_null_sycl_handler(): sycl::handler(nullptr) {}
};
#endif

} // namespace detail

/**
Expand Down Expand Up @@ -550,38 +557,30 @@ class accessor<DataT, Dims, Mode, target::host_task> : public detail::accessor_b

template <typename DataT, int Dims = 1>
class local_accessor {
#if !CELERITY_FEATURE_LOCAL_ACCESSOR
static_assert(detail::constexpr_false<DataT>, "Your SYCL implementation cannot support celerity::local_accessor");
#else
private:
#if WORKAROUND_DPCPP
#if WORKAROUND_DPCPP || WORKAROUND(COMPUTECPP, 2, 6)
using sycl_accessor = cl::sycl::accessor<DataT, Dims, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>;
#else
using sycl_accessor = cl::sycl::local_accessor<DataT, Dims>;
#endif

template <typename Index>
using subscript_type = decltype(std::declval<sycl_accessor>()[std::declval<const Index &>()]);

public:
using value_type = DataT;
using reference = DataT&;
using const_reference = const DataT&;
using size_type = size_t;

local_accessor()
#if WORKAROUND_DPCPP
: sycl_acc(allocation_size, detail::hack_make_invisible_null_reference<cl::sycl::handler>()),
#else
: sycl_acc(),
#endif
: sycl_acc{make_dangling_sycl_accessor()},
allocation_size(detail::zero_range) {
}

#if !defined(__SYCL_DEVICE_ONLY__) && !defined(SYCL_DEVICE_ONLY)
local_accessor(const range<Dims>& allocation_size, handler& cgh)
#if WORKAROUND_DPCPP
: sycl_acc(allocation_size, detail::hack_make_invisible_null_reference<cl::sycl::handler>()),
#else
: sycl_acc(),
#endif
: sycl_acc{make_dangling_sycl_accessor()},
allocation_size(allocation_size) {
if(!detail::is_prepass_handler(cgh)) {
auto& device_handler = dynamic_cast<detail::live_pass_device_handler&>(cgh);
Expand Down Expand Up @@ -612,7 +611,7 @@ class local_accessor {
std::add_pointer_t<value_type> get_pointer() const noexcept { return sycl_acc.get_pointer(); }

template <typename Index>
inline decltype(auto) operator[](const Index& index) const {
inline subscript_type<Index> operator[](const Index& index) const {
return sycl_acc[index];
}

Expand All @@ -621,8 +620,19 @@ class local_accessor {
range<Dims> allocation_size;
cl::sycl::handler* const* eventual_sycl_cgh = nullptr;

cl::sycl::handler* sycl_cgh() const { return eventual_sycl_cgh != nullptr ? *eventual_sycl_cgh : nullptr; }
static sycl_accessor make_dangling_sycl_accessor()
{
#if WORKAROUND_DPCPP
return sycl_accessor{detail::zero_range, detail::hack_make_invisible_null_reference<cl::sycl::handler>()};
#elif WORKAROUND_COMPUTECPP
detail::hack_null_sycl_handler null_cgh;
return sycl_accessor{detail::zero_range, null_cgh};
#else
return sycl_accessor{};
#endif
}

cl::sycl::handler* sycl_cgh() const { return eventual_sycl_cgh != nullptr ? *eventual_sycl_cgh : nullptr; }
};


Expand Down
2 changes: 0 additions & 2 deletions test/runtime_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2195,7 +2195,6 @@ namespace detail {
CHECK_THROWS_WITH((celerity::nd_range<3>{{256, 256, 256}, {2, 1, 0}}), "global_range is not divisible by local_range");
}

#if CELERITY_FEATURE_LOCAL_ACCESSOR
TEST_CASE("nd_range kernels support local memory", "[handler]") {
distr_queue q;
buffer<int, 1> out{64};
Expand All @@ -2219,7 +2218,6 @@ namespace detail {
});
});
}
#endif

#if CELERITY_FEATURE_SIMPLE_SCALAR_REDUCTIONS

Expand Down

0 comments on commit 8e2fce4

Please sign in to comment.