Skip to content

[NFCI][SYCL] Keep raw ptr/ref to devices/platforms in context_impl #19629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions sycl/include/sycl/ext/oneapi/experimental/current_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ inline namespace _V1 {
namespace ext::oneapi::experimental::this_thread {

namespace detail {
inline sycl::device &get_current_device_ref() {
static thread_local sycl::device current_device{sycl::default_selector_v};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that could extend sycl::device's lifetime past global_handler destruction when running our unittests, but that device wasn't really usable because we've already released all the UR resources at that point.

using namespace sycl::detail;
// Underlying `std::shared_ptr<device_impl>`'s lifetime is tied to the
// `global_handler`, so a subsequent `lock()` is expected to be successful when
// used from user app. We still go through `std::weak_ptr` here because our own
// unittests are linked statically against SYCL RT objects and have to implement
// some hacks to emulate the lifetime management done by the `global_handler`.
inline std::weak_ptr<device_impl> &get_current_device_impl() {
static thread_local std::weak_ptr<device_impl> current_device{
getSyclObjImpl(sycl::device{sycl::default_selector_v})};
return current_device;
}
} // namespace detail
Expand All @@ -28,15 +35,16 @@ inline sycl::device &get_current_device_ref() {
/// @pre The function is called from a host thread, executing outside of a host
/// task or an asynchronous error handler.
inline sycl::device get_current_device() {
return detail::get_current_device_ref();
return detail::createSyclObjFromImpl<device>(
detail::get_current_device_impl().lock());
}

/// @brief Sets the current default device to `dev` for the calling host thread.
///
/// @pre The function is called from a host thread, executing outside of a host
/// task or an asynchronous error handler.
inline void set_current_device(sycl::device dev) {
detail::get_current_device_ref() = dev;
detail::get_current_device_impl() = detail::getSyclObjImpl(dev);
}

} // namespace ext::oneapi::experimental::this_thread
Expand Down
99 changes: 50 additions & 49 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,15 @@ namespace sycl {
inline namespace _V1 {
namespace detail {

context_impl::context_impl(const std::vector<sycl::device> Devices,
async_handler AsyncHandler,
context_impl::context_impl(devices_range Devices, async_handler AsyncHandler,
const property_list &PropList, private_tag)
: MOwnedByRuntime(true), MAsyncHandler(std::move(AsyncHandler)),
MDevices(std::move(Devices)), MContext(nullptr),
MPlatform(detail::getSyclObjImpl(MDevices[0].get_platform())),
MPropList(PropList), MKernelProgramCache(*this),
MSupportBufferLocationByDevices(NotChecked) {
MDevices(Devices.to<std::vector<device_impl *>>()), MContext(nullptr),
MPlatform(MDevices[0]->getPlatformImpl()), MPropList(PropList),
MKernelProgramCache(*this), MSupportBufferLocationByDevices(NotChecked) {
verifyProps(PropList);
std::vector<ur_device_handle_t> DeviceIds;
for (const auto &D : MDevices) {
for (device_impl &D : devices_range{MDevices}) {
if (D.has(aspect::ext_oneapi_is_composite)) {
// Component devices are considered to be descendent devices from a
// composite device and therefore context created for a composite
Expand All @@ -52,7 +50,7 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
DeviceIds.push_back(getSyclObjImpl(CD)->getHandleRef());
}

DeviceIds.push_back(getSyclObjImpl(D)->getHandleRef());
DeviceIds.push_back(D.getHandleRef());
}

getAdapter().call<UrApiKind::urContextCreate>(
Expand All @@ -61,39 +59,42 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,

context_impl::context_impl(ur_context_handle_t UrContext,
async_handler AsyncHandler, adapter_impl &Adapter,
const std::vector<sycl::device> &DeviceList,
bool OwnedByRuntime, private_tag)
devices_range DeviceList, bool OwnedByRuntime,
private_tag)
: MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(std::move(AsyncHandler)),
MDevices(DeviceList), MContext(UrContext), MPlatform(),
MDevices([&]() {
if (!DeviceList.empty())
return DeviceList.to<std::vector<device_impl *>>();

std::vector<ur_device_handle_t> DeviceIds;
uint32_t DevicesNum = 0;
// TODO catch an exception and put it to list of asynchronous
// exceptions.
Adapter.call<UrApiKind::urContextGetInfo>(
UrContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum),
&DevicesNum, nullptr);
DeviceIds.resize(DevicesNum);
// TODO catch an exception and put it to list of asynchronous
// exceptions.
Adapter.call<UrApiKind::urContextGetInfo>(
UrContext, UR_CONTEXT_INFO_DEVICES,
sizeof(ur_device_handle_t) * DevicesNum, &DeviceIds[0], nullptr);

if (DeviceIds.empty())
throw exception(
make_error_code(errc::invalid),
"No devices in the provided device list and native context.");

platform_impl &Platform =
platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter);
std::vector<device_impl *> Devices;
for (ur_device_handle_t Dev : DeviceIds)
Devices.emplace_back(&Platform.getOrMakeDeviceImpl(Dev));

return Devices;
}()),
MContext(UrContext), MPlatform(MDevices[0]->getPlatformImpl()),
MKernelProgramCache(*this), MSupportBufferLocationByDevices(NotChecked) {
if (!MDevices.empty()) {
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
} else {
std::vector<ur_device_handle_t> DeviceIds;
uint32_t DevicesNum = 0;
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter.call<UrApiKind::urContextGetInfo>(
MContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum), &DevicesNum,
nullptr);
DeviceIds.resize(DevicesNum);
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter.call<UrApiKind::urContextGetInfo>(
MContext, UR_CONTEXT_INFO_DEVICES,
sizeof(ur_device_handle_t) * DevicesNum, &DeviceIds[0], nullptr);

if (DeviceIds.empty())
throw exception(
make_error_code(errc::invalid),
"No devices in the provided device list and native context.");

platform_impl &Platform =
platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter);
for (ur_device_handle_t Dev : DeviceIds) {
MDevices.emplace_back(
createSyclObjFromImpl<device>(Platform.getOrMakeDeviceImpl(Dev)));
}
MPlatform = Platform.shared_from_this();
}
// TODO catch an exception and put it to list of asynchronous exceptions
// getAdapter() will be the same as the Adapter passed. This should be taken
// care of when creating device object.
Expand Down Expand Up @@ -144,12 +145,12 @@ uint32_t context_impl::get_info<info::context::reference_count>() const {
this->getAdapter());
}
template <> platform context_impl::get_info<info::context::platform>() const {
return createSyclObjFromImpl<platform>(*MPlatform);
return createSyclObjFromImpl<platform>(MPlatform);
}
template <>
std::vector<sycl::device>
context_impl::get_info<info::context::devices>() const {
return MDevices;
return devices_range{MDevices}.to<std::vector<sycl::device>>();
}
template <>
std::vector<sycl::memory_order>
Expand Down Expand Up @@ -219,7 +220,7 @@ context_impl::get_backend_info<info::platform::version>() const {
"the info::platform::version info descriptor can "
"only be queried with an OpenCL backend");
}
return MDevices[0].get_platform().get_info<info::platform::version>();
return MDevices[0]->get_platform().get_info<info::platform::version>();
}
#endif

Expand Down Expand Up @@ -271,17 +272,17 @@ KernelProgramCache &context_impl::getKernelProgramCache() const {
}

bool context_impl::hasDevice(const detail::device_impl &Device) const {
for (auto D : MDevices)
if (getSyclObjImpl(D).get() == &Device)
for (device_impl *D : MDevices)
if (D == &Device)
return true;
return false;
}

device_impl *
context_impl::findMatchingDeviceImpl(ur_device_handle_t &DeviceUR) const {
for (device D : MDevices)
if (getSyclObjImpl(D)->getHandleRef() == DeviceUR)
return getSyclObjImpl(D).get();
for (device_impl *D : MDevices)
if (D->getHandleRef() == DeviceUR)
return D;

return nullptr;
}
Expand All @@ -301,8 +302,8 @@ bool context_impl::isBufferLocationSupported() const {
return MSupportBufferLocationByDevices == Supported ? true : false;
// Check that devices within context have support of buffer location
MSupportBufferLocationByDevices = Supported;
for (auto &Device : MDevices) {
if (!Device.has_extension("cl_intel_mem_alloc_buffer_location")) {
for (device_impl *Device : MDevices) {
if (!Device->has_extension("cl_intel_mem_alloc_buffer_location")) {
MSupportBufferLocationByDevices = NotSupported;
break;
}
Expand Down
24 changes: 9 additions & 15 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
/// \param DeviceList is a list of SYCL device instances.
/// \param AsyncHandler is an instance of async_handler.
/// \param PropList is an instance of property_list.
context_impl(const std::vector<sycl::device> DeviceList,
async_handler AsyncHandler, const property_list &PropList,
private_tag);
context_impl(devices_range DeviceList, async_handler AsyncHandler,
const property_list &PropList, private_tag);

/// Construct a context_impl using plug-in interoperability handle.
///
Expand All @@ -62,9 +61,8 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
/// \param OwnedByRuntime is the flag if ownership is kept by user or
/// transferred to runtime
context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
adapter_impl &Adapter,
const std::vector<sycl::device> &DeviceList, bool OwnedByRuntime,
private_tag);
adapter_impl &Adapter, devices_range DeviceList,
bool OwnedByRuntime, private_tag);

context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
adapter_impl &Adapter, private_tag tag)
Expand Down Expand Up @@ -94,10 +92,10 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
const async_handler &get_async_handler() const;

/// \return the Adapter associated with the platform of this context.
adapter_impl &getAdapter() const { return MPlatform->getAdapter(); }
adapter_impl &getAdapter() const { return MPlatform.getAdapter(); }

/// \return the PlatformImpl associated with this context.
platform_impl &getPlatformImpl() const { return *MPlatform; }
platform_impl &getPlatformImpl() const { return MPlatform; }

/// Queries this context for information.
///
Expand Down Expand Up @@ -191,10 +189,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
}

// Returns the backend of this context
backend getBackend() const {
assert(MPlatform && "MPlatform must be not null");
return MPlatform->getBackend();
}
backend getBackend() const { return MPlatform.getBackend(); }

/// Given a UR device, returns the matching shared_ptr<device_impl>
/// within this context. May return nullptr if no match discovered.
Expand Down Expand Up @@ -262,10 +257,9 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
private:
bool MOwnedByRuntime;
async_handler MAsyncHandler;
std::vector<device> MDevices;
std::vector<device_impl *> MDevices;
ur_context_handle_t MContext;
// TODO: Make it a reference instead, but that needs a bit more refactoring:
std::shared_ptr<platform_impl> MPlatform;
platform_impl &MPlatform;
property_list MPropList;
CachedLibProgramsT MCachedLibPrograms;
std::mutex MCachedLibProgramsMutex;
Expand Down