Skip to content

Commit

Permalink
[SYCL] [USM] Implement prefetch.
Browse files Browse the repository at this point in the history
Signed-off-by: Joshua Cranmer <joshua.cranmer@intel.com>
  • Loading branch information
jcranmer-intel authored and romanovvlad committed Oct 8, 2019
1 parent 866d634 commit feeacc1
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 1 deletion.
23 changes: 22 additions & 1 deletion sycl/include/CL/sycl/detail/cg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,8 @@ class CG {
UPDATE_HOST,
RUN_ON_HOST_INTEL,
COPY_USM,
FILL_USM
FILL_USM,
PREFETCH_USM
};

CG(CGTYPE Type, std::vector<std::vector<char>> ArgsStorage,
Expand Down Expand Up @@ -529,6 +530,26 @@ class CGFillUSM : public CG {
// Returns the first element of the stored fill pattern (MPattern[0]),
// converted to int.
int getFill() { return MPattern[0]; }
};

/// Command group that represents a "prefetch" request on a USM pointer.
///
/// The object only records the target pointer and the length given to the
/// constructor; all storage/requirement/event vectors are moved into the CG
/// base class, which is tagged with the PREFETCH_USM type.
class CGPrefetchUSM : public CG {
public:
  /// \param DstPtr pointer to the USM memory that should be prefetched.
  /// \param Length length of the region, later reported by getLength().
  /// The remaining arguments are forwarded (by move) to the CG constructor.
  CGPrefetchUSM(void *DstPtr, size_t Length,
                std::vector<std::vector<char>> ArgsStorage,
                std::vector<detail::AccessorImplPtr> AccStorage,
                std::vector<std::shared_ptr<const void>> SharedPtrStorage,
                std::vector<Requirement *> Requirements,
                std::vector<detail::EventImplPtr> Events)
      : CG(PREFETCH_USM, std::move(ArgsStorage), std::move(AccStorage),
           std::move(SharedPtrStorage), std::move(Requirements),
           std::move(Events)),
        MDst(DstPtr), MLength(Length) {}

  /// Returns the pointer supplied at construction time.
  void *getDst() { return MDst; }

  /// Returns the length supplied at construction time.
  size_t getLength() { return MLength; }

private:
  void *MDst;     // Start of the region to prefetch.
  size_t MLength; // Length of the region to prefetch.
};

} // namespace detail
} // namespace sycl
} // namespace cl
4 changes: 4 additions & 0 deletions sycl/include/CL/sycl/detail/memory_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ class MemoryManager {
int Pattern, std::vector<RT::PiEvent> DepEvents,
RT::PiEvent &OutEvent);

// Requests a prefetch of the USM region starting at Ptr and spanning Len on
// the given Queue. DepEvents is handed to the backend as the wait list and
// OutEvent receives the event produced by the enqueue. When the queue's
// context is a host context nothing is enqueued.
static void prefetch_usm(void *Ptr, QueueImplPtr Queue, size_t Len,
std::vector<RT::PiEvent> DepEvents,
RT::PiEvent &OutEvent);

};
} // namespace detail
} // namespace sycl
Expand Down
3 changes: 3 additions & 0 deletions sycl/include/CL/sycl/detail/usm_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class USMDispatcher {
void *ParamValue, size_t *ParamValueSizeRet);
void memAdvise(pi_queue Queue, const void *Ptr, size_t Length, int Advice,
pi_event *Event);
// Enqueues a prefetch request for the region described by Ptr and Size on
// Queue and returns the PI result code. The current implementation only
// handles the OpenCL backend, where it enqueues a wait on EventWaitList
// (reporting completion through Event) instead of a real prefetch; for any
// other backend it returns PI_INVALID_OPERATION.
pi_result enqueuePrefetch(pi_queue Queue, void *Ptr, size_t Size,
pi_uint32 NumEventsInWaitList,
const pi_event *EventWaitList, pi_event *Event);

private:
bool mEmulated = false;
Expand Down
13 changes: 13 additions & 0 deletions sycl/include/CL/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,12 @@ class handler {
std::move(MAccStorage), std::move(MSharedPtrStorage),
std::move(MRequirements), std::move(MEvents)));
break;
case detail::CG::PREFETCH_USM:
CommandGroup.reset(new detail::CGPrefetchUSM(
MDstPtr, MLength, std::move(MArgsStorage),
std::move(MAccStorage), std::move(MSharedPtrStorage),
std::move(MRequirements), std::move(MEvents)));
break;
case detail::CG::NONE:
throw runtime_error("Command group submitted without a kernel or a "
"explicit memory operation.");
Expand Down Expand Up @@ -1163,6 +1169,13 @@ class handler {
MLength = Count;
MCGType = detail::CG::FILL_USM;
}

/// Turns this command group into a USM prefetch request.
///
/// \param Ptr   start of the memory to prefetch; constness is dropped only
///              because the pointer is stored in the shared MDstPtr slot.
/// \param Count length of the region, stored in MLength.
void prefetch(const void *Ptr, size_t Count) {
  MCGType = detail::CG::PREFETCH_USM;
  MLength = Count;
  MDstPtr = const_cast<void *>(Ptr);
}
};
} // namespace sycl
} // namespace cl
6 changes: 6 additions & 0 deletions sycl/include/CL/sycl/ordered_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ class ordered_queue {
return impl->memcpy(impl, dest, src, count);
}

/// Submits a command group that prefetches Count units of memory starting at
/// Ptr and returns the event associated with that submission.
event prefetch(const void *Ptr, size_t Count) {
  auto PrefetchCG = [Ptr, Count](handler &cgh) { cgh.prefetch(Ptr, Count); };
  return submit(PrefetchCG);
}

private:
std::shared_ptr<detail::queue_impl> impl;
template <class Obj>
Expand Down
6 changes: 6 additions & 0 deletions sycl/include/CL/sycl/queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ class queue {
return impl->mem_advise(Ptr, Length, Advice);
}

/// Submits a command group that prefetches Count units of memory starting at
/// Ptr and returns the event associated with that submission.
event prefetch(const void *Ptr, size_t Count) {
  auto PrefetchCG = [Ptr, Count](handler &cgh) { cgh.prefetch(Ptr, Count); };
  return submit(PrefetchCG);
}

private:
std::shared_ptr<detail::queue_impl> impl;
template <class Obj>
Expand Down
15 changes: 15 additions & 0 deletions sycl/source/detail/memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,21 @@ void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
}
}

void MemoryManager::prefetch_usm(void *Mem, QueueImplPtr Queue, size_t Length,
std::vector<RT::PiEvent> DepEvents,
RT::PiEvent &OutEvent) {
sycl::context Context = Queue->get_context();

if (Context.is_host()) {
// TODO: Potentially implement prefetch on the host.
} else {
std::shared_ptr<usm::USMDispatcher> USMDispatch =
getSyclObjImpl(Context)->getUSMDispatch();
PI_CHECK(USMDispatch->enqueuePrefetch(Queue->getHandleRef(),
Mem, Length, DepEvents.size(), &DepEvents[0], &OutEvent));
}
}

} // namespace detail
} // namespace sycl
} // namespace cl
9 changes: 9 additions & 0 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,9 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
case detail::CG::FILL_USM:
Stream << "CG type: fill usm\\n";
break;
case detail::CG::PREFETCH_USM:
Stream << "CG type: prefetch usm\\n";
break;
default:
Stream << "CG type: unknown\\n";
break;
Expand Down Expand Up @@ -785,6 +788,12 @@ cl_int ExecCGCommand::enqueueImp() {
Fill->getFill(), std::move(RawEvents), Event);
return CL_SUCCESS;
}
case CG::CGTYPE::PREFETCH_USM: {
CGPrefetchUSM *Prefetch = (CGPrefetchUSM *)MCommandGroup.get();
MemoryManager::prefetch_usm(Prefetch->getDst(), MQueue,
Prefetch->getLength(), std::move(RawEvents), Event);
return CL_SUCCESS;
}
case CG::CGTYPE::NONE:
default:
throw runtime_error("CG type not implemented.");
Expand Down
23 changes: 23 additions & 0 deletions sycl/source/detail/usm/usm_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,29 @@ void USMDispatcher::memAdvise(pi_queue Queue, const void *Ptr, size_t Length,
}
}
}

/// Enqueues a prefetch request on \p Queue and returns the PI result code.
///
/// Only the OpenCL backend is handled; any other backend yields
/// PI_INVALID_OPERATION. On OpenCL the request is currently lowered to a wait
/// on \p EventWaitList (completion reported through \p Event) in both the
/// emulated and the native case, so \p Ptr and \p Size are not consulted yet.
pi_result USMDispatcher::enqueuePrefetch(pi_queue Queue, void *Ptr, size_t Size,
                                         pi_uint32 NumEventsInWaitList,
                                         const pi_event *EventWaitList,
                                         pi_event *Event) {
  if (!pi::useBackend(pi::Backend::SYCL_BE_PI_OPENCL))
    return PI_INVALID_OPERATION;

  if (mEmulated) {
    // Prefetch is a hint, so ignoring it is always safe.
    return PI_CALL_RESULT(RT::piEnqueueEventsWait(Queue, NumEventsInWaitList,
                                                  EventWaitList, Event));
  }

  // TODO: Replace this with real prefetch support when the driver enables
  // it.
  return PI_CALL_RESULT(RT::piEnqueueEventsWait(Queue, NumEventsInWaitList,
                                                EventWaitList, Event));
}

} // namespace usm
} // namespace detail
} // namespace sycl
Expand Down
89 changes: 89 additions & 0 deletions sycl/test/usm/prefetch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//==---- prefetch.cpp - USM prefetch test ----------------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// RUN: %clangxx -fsycl %s -o %t1.out -lOpenCL
// RUN: env SYCL_DEVICE_TYPE=HOST %t1.out
// RUN: %CPU_RUN_PLACEHOLDER %t1.out

#include <CL/sycl.hpp>

#include <cassert>
#include <exception>

using namespace cl::sycl;

static constexpr int count = 100;

int main() {
queue q([](exception_list el) {
for (auto &e : el)
throw e;
});
float *src = (float*)malloc_shared(sizeof(float) * count, q.get_device(),
q.get_context());
float *dest = (float*)malloc_shared(sizeof(float) * count, q.get_device(),
q.get_context());
for (int i = 0; i < count; i++)
src[i] = i;

// Test handler::prefetch
{
event init_prefetch = q.submit([&](handler &cgh) {
cgh.prefetch(src, sizeof(float) * count);
});

q.submit([&](handler &cgh) {
cgh.depends_on(init_prefetch);
cgh.single_task<class double_dest>([=]() {
for (int i = 0; i < count; i++)
dest[i] = 2 * src[i];
});
});
q.wait_and_throw();

for (int i = 0; i < count; i++) {
assert(dest[i] == i * 2);
}
}

// Test queue::prefetch
{
event init_prefetch = q.prefetch(src, sizeof(float) * count);

q.submit([&](handler &cgh) {
cgh.depends_on(init_prefetch);
cgh.single_task<class double_dest3>([=]() {
for (int i = 0; i < count; i++)
dest[i] = 3 * src[i];
});
});
q.wait_and_throw();

for (int i = 0; i < count; i++) {
assert(dest[i] == i * 3);
}
}

// Test ordered_queue::prefetch
{
ordered_queue oq([](exception_list el) {
for (auto &e : el)
throw e;
});
event init_prefetch = oq.prefetch(src, sizeof(float) * count);

oq.submit([&](handler &cgh) {
cgh.depends_on(init_prefetch);
cgh.single_task<class double_dest4>([=]() {
for (int i = 0; i < count; i++)
dest[i] = 4 * src[i];
});
});
oq.wait_and_throw();

for (int i = 0; i < count; i++) {
assert(dest[i] == i * 4);
}
}
}

0 comments on commit feeacc1

Please sign in to comment.