Skip to content

Commit 824e751

Browse files
[NFCI][SYCL] Keep raw ptr/ref to devices/platforms in context_impl (#19629)
Similar to #19613. Refactoring has started in #18143 #18251
1 parent 521114d commit 824e751

File tree

3 files changed

+71
-68
lines changed

3 files changed

+71
-68
lines changed

sycl/include/sycl/ext/oneapi/experimental/current_device.hpp

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,15 @@ inline namespace _V1 {
1515
namespace ext::oneapi::experimental::this_thread {
1616

1717
namespace detail {
18-
inline sycl::device &get_current_device_ref() {
19-
static thread_local sycl::device current_device{sycl::default_selector_v};
18+
using namespace sycl::detail;
19+
// Underlying `std::shared_ptr<device_impl>`'s lifetime is tied to the
20+
// `global_handler`, so a subsequent `lock()` is expected to be successful when
21+
// used from user app. We still go through `std::weak_ptr` here because our own
22+
// unittests are linked statically against SYCL RT objects and have to implement
23+
// some hacks to emulate the lifetime management done by the `global_handler`.
24+
inline std::weak_ptr<device_impl> &get_current_device_impl() {
25+
static thread_local std::weak_ptr<device_impl> current_device{
26+
getSyclObjImpl(sycl::device{sycl::default_selector_v})};
2027
return current_device;
2128
}
2229
} // namespace detail
@@ -28,15 +35,16 @@ inline sycl::device &get_current_device_ref() {
2835
/// @pre The function is called from a host thread, executing outside of a host
2936
/// task or an asynchronous error handler.
3037
inline sycl::device get_current_device() {
31-
return detail::get_current_device_ref();
38+
return detail::createSyclObjFromImpl<device>(
39+
detail::get_current_device_impl().lock());
3240
}
3341

3442
/// @brief Sets the current default device to `dev` for the calling host thread.
3543
///
3644
/// @pre The function is called from a host thread, executing outside of a host
3745
/// task or an asynchronous error handler.
3846
inline void set_current_device(sycl::device dev) {
39-
detail::get_current_device_ref() = dev;
47+
detail::get_current_device_impl() = detail::getSyclObjImpl(dev);
4048
}
4149

4250
} // namespace ext::oneapi::experimental::this_thread

sycl/source/detail/context_impl.cpp

Lines changed: 50 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -29,17 +29,15 @@ namespace sycl {
2929
inline namespace _V1 {
3030
namespace detail {
3131

32-
context_impl::context_impl(const std::vector<sycl::device> Devices,
33-
async_handler AsyncHandler,
32+
context_impl::context_impl(devices_range Devices, async_handler AsyncHandler,
3433
const property_list &PropList, private_tag)
3534
: MOwnedByRuntime(true), MAsyncHandler(std::move(AsyncHandler)),
36-
MDevices(std::move(Devices)), MContext(nullptr),
37-
MPlatform(detail::getSyclObjImpl(MDevices[0].get_platform())),
38-
MPropList(PropList), MKernelProgramCache(*this),
39-
MSupportBufferLocationByDevices(NotChecked) {
35+
MDevices(Devices.to<std::vector<device_impl *>>()), MContext(nullptr),
36+
MPlatform(MDevices[0]->getPlatformImpl()), MPropList(PropList),
37+
MKernelProgramCache(*this), MSupportBufferLocationByDevices(NotChecked) {
4038
verifyProps(PropList);
4139
std::vector<ur_device_handle_t> DeviceIds;
42-
for (const auto &D : MDevices) {
40+
for (device_impl &D : devices_range{MDevices}) {
4341
if (D.has(aspect::ext_oneapi_is_composite)) {
4442
// Component devices are considered to be descendent devices from a
4543
// composite device and therefore context created for a composite
@@ -52,7 +50,7 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
5250
DeviceIds.push_back(getSyclObjImpl(CD)->getHandleRef());
5351
}
5452

55-
DeviceIds.push_back(getSyclObjImpl(D)->getHandleRef());
53+
DeviceIds.push_back(D.getHandleRef());
5654
}
5755

5856
getAdapter().call<UrApiKind::urContextCreate>(
@@ -61,39 +59,42 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
6159

6260
context_impl::context_impl(ur_context_handle_t UrContext,
6361
async_handler AsyncHandler, adapter_impl &Adapter,
64-
const std::vector<sycl::device> &DeviceList,
65-
bool OwnedByRuntime, private_tag)
62+
devices_range DeviceList, bool OwnedByRuntime,
63+
private_tag)
6664
: MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(std::move(AsyncHandler)),
67-
MDevices(DeviceList), MContext(UrContext), MPlatform(),
65+
MDevices([&]() {
66+
if (!DeviceList.empty())
67+
return DeviceList.to<std::vector<device_impl *>>();
68+
69+
std::vector<ur_device_handle_t> DeviceIds;
70+
uint32_t DevicesNum = 0;
71+
// TODO catch an exception and put it to list of asynchronous
72+
// exceptions.
73+
Adapter.call<UrApiKind::urContextGetInfo>(
74+
UrContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum),
75+
&DevicesNum, nullptr);
76+
DeviceIds.resize(DevicesNum);
77+
// TODO catch an exception and put it to list of asynchronous
78+
// exceptions.
79+
Adapter.call<UrApiKind::urContextGetInfo>(
80+
UrContext, UR_CONTEXT_INFO_DEVICES,
81+
sizeof(ur_device_handle_t) * DevicesNum, &DeviceIds[0], nullptr);
82+
83+
if (DeviceIds.empty())
84+
throw exception(
85+
make_error_code(errc::invalid),
86+
"No devices in the provided device list and native context.");
87+
88+
platform_impl &Platform =
89+
platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter);
90+
std::vector<device_impl *> Devices;
91+
for (ur_device_handle_t Dev : DeviceIds)
92+
Devices.emplace_back(&Platform.getOrMakeDeviceImpl(Dev));
93+
94+
return Devices;
95+
}()),
96+
MContext(UrContext), MPlatform(MDevices[0]->getPlatformImpl()),
6897
MKernelProgramCache(*this), MSupportBufferLocationByDevices(NotChecked) {
69-
if (!MDevices.empty()) {
70-
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
71-
} else {
72-
std::vector<ur_device_handle_t> DeviceIds;
73-
uint32_t DevicesNum = 0;
74-
// TODO catch an exception and put it to list of asynchronous exceptions
75-
Adapter.call<UrApiKind::urContextGetInfo>(
76-
MContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum), &DevicesNum,
77-
nullptr);
78-
DeviceIds.resize(DevicesNum);
79-
// TODO catch an exception and put it to list of asynchronous exceptions
80-
Adapter.call<UrApiKind::urContextGetInfo>(
81-
MContext, UR_CONTEXT_INFO_DEVICES,
82-
sizeof(ur_device_handle_t) * DevicesNum, &DeviceIds[0], nullptr);
83-
84-
if (DeviceIds.empty())
85-
throw exception(
86-
make_error_code(errc::invalid),
87-
"No devices in the provided device list and native context.");
88-
89-
platform_impl &Platform =
90-
platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter);
91-
for (ur_device_handle_t Dev : DeviceIds) {
92-
MDevices.emplace_back(
93-
createSyclObjFromImpl<device>(Platform.getOrMakeDeviceImpl(Dev)));
94-
}
95-
MPlatform = Platform.shared_from_this();
96-
}
9798
// TODO catch an exception and put it to list of asynchronous exceptions
9899
// getAdapter() will be the same as the Adapter passed. This should be taken
99100
// care of when creating device object.
@@ -144,12 +145,12 @@ uint32_t context_impl::get_info<info::context::reference_count>() const {
144145
this->getAdapter());
145146
}
146147
template <> platform context_impl::get_info<info::context::platform>() const {
147-
return createSyclObjFromImpl<platform>(*MPlatform);
148+
return createSyclObjFromImpl<platform>(MPlatform);
148149
}
149150
template <>
150151
std::vector<sycl::device>
151152
context_impl::get_info<info::context::devices>() const {
152-
return MDevices;
153+
return devices_range{MDevices}.to<std::vector<sycl::device>>();
153154
}
154155
template <>
155156
std::vector<sycl::memory_order>
@@ -219,7 +220,7 @@ context_impl::get_backend_info<info::platform::version>() const {
219220
"the info::platform::version info descriptor can "
220221
"only be queried with an OpenCL backend");
221222
}
222-
return MDevices[0].get_platform().get_info<info::platform::version>();
223+
return MDevices[0]->get_platform().get_info<info::platform::version>();
223224
}
224225
#endif
225226

@@ -271,17 +272,17 @@ KernelProgramCache &context_impl::getKernelProgramCache() const {
271272
}
272273

273274
bool context_impl::hasDevice(const detail::device_impl &Device) const {
274-
for (auto D : MDevices)
275-
if (getSyclObjImpl(D).get() == &Device)
275+
for (device_impl *D : MDevices)
276+
if (D == &Device)
276277
return true;
277278
return false;
278279
}
279280

280281
device_impl *
281282
context_impl::findMatchingDeviceImpl(ur_device_handle_t &DeviceUR) const {
282-
for (device D : MDevices)
283-
if (getSyclObjImpl(D)->getHandleRef() == DeviceUR)
284-
return getSyclObjImpl(D).get();
283+
for (device_impl *D : MDevices)
284+
if (D->getHandleRef() == DeviceUR)
285+
return D;
285286

286287
return nullptr;
287288
}
@@ -301,8 +302,8 @@ bool context_impl::isBufferLocationSupported() const {
301302
return MSupportBufferLocationByDevices == Supported ? true : false;
302303
// Check that devices within context have support of buffer location
303304
MSupportBufferLocationByDevices = Supported;
304-
for (auto &Device : MDevices) {
305-
if (!Device.has_extension("cl_intel_mem_alloc_buffer_location")) {
305+
for (device_impl *Device : MDevices) {
306+
if (!Device->has_extension("cl_intel_mem_alloc_buffer_location")) {
306307
MSupportBufferLocationByDevices = NotSupported;
307308
break;
308309
}

sycl/source/detail/context_impl.hpp

Lines changed: 9 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -47,9 +47,8 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
4747
/// \param DeviceList is a list of SYCL device instances.
4848
/// \param AsyncHandler is an instance of async_handler.
4949
/// \param PropList is an instance of property_list.
50-
context_impl(const std::vector<sycl::device> DeviceList,
51-
async_handler AsyncHandler, const property_list &PropList,
52-
private_tag);
50+
context_impl(devices_range DeviceList, async_handler AsyncHandler,
51+
const property_list &PropList, private_tag);
5352

5453
/// Construct a context_impl using plug-in interoperability handle.
5554
///
@@ -62,9 +61,8 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
6261
/// \param OwnedByRuntime is the flag if ownership is kept by user or
6362
/// transferred to runtime
6463
context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
65-
adapter_impl &Adapter,
66-
const std::vector<sycl::device> &DeviceList, bool OwnedByRuntime,
67-
private_tag);
64+
adapter_impl &Adapter, devices_range DeviceList,
65+
bool OwnedByRuntime, private_tag);
6866

6967
context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
7068
adapter_impl &Adapter, private_tag tag)
@@ -94,10 +92,10 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
9492
const async_handler &get_async_handler() const;
9593

9694
/// \return the Adapter associated with the platform of this context.
97-
adapter_impl &getAdapter() const { return MPlatform->getAdapter(); }
95+
adapter_impl &getAdapter() const { return MPlatform.getAdapter(); }
9896

9997
/// \return the PlatformImpl associated with this context.
100-
platform_impl &getPlatformImpl() const { return *MPlatform; }
98+
platform_impl &getPlatformImpl() const { return MPlatform; }
10199

102100
/// Queries this context for information.
103101
///
@@ -191,10 +189,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
191189
}
192190

193191
// Returns the backend of this context
194-
backend getBackend() const {
195-
assert(MPlatform && "MPlatform must be not null");
196-
return MPlatform->getBackend();
197-
}
192+
backend getBackend() const { return MPlatform.getBackend(); }
198193

199194
/// Given a UR device, returns the matching shared_ptr<device_impl>
200195
/// within this context. May return nullptr if no match discovered.
@@ -262,10 +257,9 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
262257
private:
263258
bool MOwnedByRuntime;
264259
async_handler MAsyncHandler;
265-
std::vector<device> MDevices;
260+
std::vector<device_impl *> MDevices;
266261
ur_context_handle_t MContext;
267-
// TODO: Make it a reference instead, but that needs a bit more refactoring:
268-
std::shared_ptr<platform_impl> MPlatform;
262+
platform_impl &MPlatform;
269263
property_list MPropList;
270264
CachedLibProgramsT MCachedLibPrograms;
271265
std::mutex MCachedLibProgramsMutex;

0 commit comments

Comments
 (0)