Skip to content

Commit

Permalink
[SYCL] Initial support for virtual functions in runtime (#14382)
Browse files Browse the repository at this point in the history
Implementation design can be found in #10540.

Main responsibility of the runtime is to gather device images dependent
on each other and link them together, carefully updating caches so that
when we launch a kernel from a group using the same virtual functions,
we re-use a program that we linked just once.

Missing parts:
- handling of kernel bundles
- handling of optional kernel features (selection of "dummy" device
images when a virtual function uses unsupported features)
  • Loading branch information
AlexeySachkov authored Jul 10, 2024
1 parent 7221b17 commit c7d018f
Show file tree
Hide file tree
Showing 10 changed files with 576 additions and 42 deletions.
2 changes: 2 additions & 0 deletions sycl/include/sycl/detail/pi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,8 @@ static const uint8_t PI_DEVICE_BINARY_OFFLOAD_KIND_SYCL = 4;
"SYCL/device requirements"
/// PropertySetRegistry::SYCL_HOST_PIPES defined in PropertySetIO.h
#define __SYCL_PI_PROPERTY_SET_SYCL_HOST_PIPES "SYCL/host pipes"
/// PropertySetRegistry::SYCL_VIRTUAL_FUNCTIONS defined in PropertySetIO.h
#define __SYCL_PI_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS "SYCL/virtual functions"

/// Program metadata tags recognized by the PI backends. For kernels the tag
/// must appear after the kernel name.
Expand Down
9 changes: 7 additions & 2 deletions sycl/source/detail/device_binary_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,13 @@ ByteArray DeviceBinaryProperty::asByteArray() const {
}

/// Returns the property value as a C string.
///
/// Both string and byte-array properties are accepted. A byte-array value
/// begins with an 8-byte size header, which is skipped so that the returned
/// pointer addresses the payload itself.
/// NOTE(review): for byte arrays this presumes the payload is NUL-terminated;
/// confirm against the producer of the property.
const char *DeviceBinaryProperty::asCString() const {
  assert((Prop->Type == PI_PROPERTY_TYPE_STRING ||
          Prop->Type == PI_PROPERTY_TYPE_BYTE_ARRAY) &&
         "property type mismatch");
  assert(Prop->ValSize > 0 && "property size mismatch");
  // Byte array stores its size in first 8 bytes
  const size_t Shift = Prop->Type == PI_PROPERTY_TYPE_BYTE_ARRAY ? 8 : 0;
  return pi::cast<const char *>(Prop->ValAddr) + Shift;
}

void RTDeviceBinaryImage::PropertyRange::init(pi_device_binary Bin,
Expand Down Expand Up @@ -177,6 +181,7 @@ void RTDeviceBinaryImage::init(pi_device_binary Bin) {
DeviceGlobals.init(Bin, __SYCL_PI_PROPERTY_SET_SYCL_DEVICE_GLOBALS);
DeviceRequirements.init(Bin, __SYCL_PI_PROPERTY_SET_SYCL_DEVICE_REQUIREMENTS);
HostPipes.init(Bin, __SYCL_PI_PROPERTY_SET_SYCL_HOST_PIPES);
VirtualFunctions.init(Bin, __SYCL_PI_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS);

ImageId = ImageCounter++;
}
Expand Down
2 changes: 2 additions & 0 deletions sycl/source/detail/device_binary_image.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ class RTDeviceBinaryImage {
return DeviceRequirements;
}
const PropertyRange &getHostPipes() const { return HostPipes; }
/// Returns the range of properties stored in the VirtualFunctions member.
const PropertyRange &getVirtualFunctions() const { return VirtualFunctions; }

std::uintptr_t getImageID() const {
assert(Bin && "Image ID is not available without a binary image.");
Expand All @@ -242,6 +243,7 @@ class RTDeviceBinaryImage {
RTDeviceBinaryImage::PropertyRange DeviceGlobals;
RTDeviceBinaryImage::PropertyRange DeviceRequirements;
RTDeviceBinaryImage::PropertyRange HostPipes;
RTDeviceBinaryImage::PropertyRange VirtualFunctions;

private:
static std::atomic<uintptr_t> ImageCounter;
Expand Down
27 changes: 27 additions & 0 deletions sycl/source/detail/kernel_program_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ class KernelProgramCache {
/// Creates a build result bound to \p Plugin that does not hold a program
/// handle yet (Val is reset to nullptr).
ProgramBuildResult(const PluginPtr &Plugin) : Plugin(Plugin) {
Val = nullptr;
}
/// Creates a build result bound to \p Plugin whose state is preset to
/// \p InitialState. The program handle (Val) starts out as nullptr.
ProgramBuildResult(const PluginPtr &Plugin, BuildState InitialState)
    : ProgramBuildResult(Plugin) {
  // The delegated-to constructor has already stored the plugin and reset Val;
  // only the build state remains to be set.
  this->State.store(InitialState);
}
~ProgramBuildResult() {
if (Val) {
sycl::detail::pi::PiResult Err =
Expand Down Expand Up @@ -184,6 +189,28 @@ class KernelProgramCache {
return std::make_pair(It->second, DidInsert);
}

/// Registers an already built program under an additional cache key.
///
/// Intended for situations where several cache keys correspond to the same
/// program, e.g. a multi-device build or kernels that use virtual functions.
///
/// If \p CacheKey is already present, the cache is left untouched. This
/// function does not itself change the reference count of \p Program.
///
/// \returns true if a new entry was inserted, false otherwise.
bool insertBuiltProgram(const ProgramCacheKeyT &CacheKey,
                        sycl::detail::pi::PiProgram Program) {
  auto Locked = acquireCachedPrograms();
  auto &Programs = Locked.get();

  auto Inserted = Programs.Cache.try_emplace(CacheKey, nullptr);
  if (!Inserted.second)
    return false;

  // Fresh slot: fill it with a result that is already marked as built.
  auto &Entry = Inserted.first->second;
  Entry = std::make_shared<ProgramBuildResult>(getPlugin(),
                                               BuildState::BS_Done);
  Entry->Val = Program;

  // Remember which full key the common (image ID, device) key maps to.
  CommonProgramKeyT CommonKey =
      std::make_pair(CacheKey.first.second, CacheKey.second);
  Programs.KeyMap.emplace(CommonKey, CacheKey);
  return true;
}

std::pair<KernelBuildResultPtr, bool>
getOrInsertKernel(sycl::detail::pi::PiProgram Program,
const std::string &KernelName) {
Expand Down
209 changes: 174 additions & 35 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <detail/program_manager/program_manager.hpp>
#include <detail/queue_impl.hpp>
#include <detail/spec_constant_impl.hpp>
#include <detail/split_string.hpp>
#include <sycl/aspects.hpp>
#include <sycl/backend_types.hpp>
#include <sycl/context.hpp>
Expand Down Expand Up @@ -517,6 +518,98 @@ static void emitBuiltProgramInfo(const pi_program &Prog,
}
}

std::set<RTDeviceBinaryImage *>
ProgramManager::collectDependentDeviceImagesForVirtualFunctions(
const RTDeviceBinaryImage &Img, device Dev) {
// If virtual functions are used in a program, then we need to link several
// device images together to make sure that vtable pointers stored in
// objects are valid between different kernels (which could be in different
// device images).
std::set<RTDeviceBinaryImage *> DeviceImagesToLink;
// KernelA may use some set-a, which is also used by KernelB that in turn
// uses set-b, meaning that this search should be recursive. The set below
// is used to stop that recursion, i.e. to avoid looking at sets we have
// already seen.
std::set<std::string> HandledSets;
std::queue<std::string> WorkList;
for (const pi_device_binary_property &VFProp : Img.getVirtualFunctions()) {
std::string StrValue = DeviceBinaryProperty(VFProp).asCString();
// Device image passed to this function is expected to contain SYCL kernels
// and therefore it may only use virtual function sets, but cannot provide
// them. We expect to see just a single property here
assert(std::string(VFProp->Name) == "uses-virtual-functions-set" &&
"Unexpected virtual function property");
for (const auto &SetName : detail::split_string(StrValue, ',')) {
WorkList.push(SetName);
HandledSets.insert(SetName);
}
}

while (!WorkList.empty()) {
std::string SetName = WorkList.front();
WorkList.pop();

// There could be more than one device image that uses the same set
// of virtual functions, or provides virtual funtions from the same
// set.
for (RTDeviceBinaryImage *BinImage : m_VFSet2BinImage[SetName]) {
// Here we can encounter both uses-virtual-functions-set and
// virtual-functions-set properties, but their handling is the same: we
// just grab all sets they reference and add them for consideration if
// we haven't done so already.
for (const pi_device_binary_property &VFProp :
BinImage->getVirtualFunctions()) {
std::string StrValue = DeviceBinaryProperty(VFProp).asCString();
for (const auto &SetName : detail::split_string(StrValue, ',')) {
if (HandledSets.insert(SetName).second)
WorkList.push(SetName);
}
}

// TODO: Complete this part about handling of incompatible device images.
// If device image uses the same virtual function set, then we only
// link it if it is compatible.
// However, if device image provides virtual function set and it is
// incompatible, then we should link its "dummy" version to avoid link
// errors about unresolved external symbols.
if (doesDevSupportDeviceRequirements(Dev, *BinImage))
DeviceImagesToLink.insert(BinImage);
}
}

// We may have inserted the original image into the list as well, because it
// is also a part of m_VFSet2BinImage map. No need to to return it to avoid
// passing it twice to link call later.
DeviceImagesToLink.erase(const_cast<RTDeviceBinaryImage *>(&Img));

return DeviceImagesToLink;
}

/// Applies the specialization constants recorded in \p InputImpl to \p Prog.
///
/// The ITT annotation specialization constant is enabled first if needed.
/// After that, every descriptor flagged as set is forwarded to the plugin
/// together with its bytes from the image's specialization constant blob.
/// The image's spec-constant data lock is held while descriptors are read.
static void
setSpecializationConstants(const std::shared_ptr<device_image_impl> &InputImpl,
                           sycl::detail::pi::PiProgram Prog,
                           const PluginPtr &Plugin) {
  enableITTAnnotationsIfNeeded(Prog, Plugin);

  std::lock_guard<std::mutex> Guard{InputImpl->get_spec_const_data_lock()};
  const auto &DescsByName = InputImpl->get_spec_const_data_ref();
  const SerializedObj &Blob = InputImpl->get_spec_const_blob_ref();

  // The names are irrelevant here; only the descriptors are consumed.
  for (const auto &Entry : DescsByName) {
    for (const device_image_impl::SpecConstDescT &Desc : Entry.second) {
      // Skip descriptors that were never given a value.
      if (!Desc.IsSet)
        continue;
      Plugin->call<PiApiKind::piextProgramSetSpecializationConstant>(
          Prog, Desc.ID, Desc.Size, Blob.data() + Desc.BlobOffset);
    }
  }
}

// When caching is enabled, the returned PiProgram will already have
// its ref count incremented.
sycl::detail::pi::PiProgram ProgramManager::getBuiltPIProgram(
Expand Down Expand Up @@ -560,8 +653,10 @@ sycl::detail::pi::PiProgram ProgramManager::getBuiltPIProgram(
if (auto exception = checkDevSupportDeviceRequirements(Device, Img, NDRDesc))
throw *exception;

std::set<RTDeviceBinaryImage *> DeviceImagesToLink =
collectDependentDeviceImagesForVirtualFunctions(Img, Device);
auto BuildF = [this, &Img, &Context, &ContextImpl, &Device, &CompileOpts,
&LinkOpts, SpecConsts] {
&LinkOpts, SpecConsts, &DeviceImagesToLink] {
const PluginPtr &Plugin = ContextImpl->getPlugin();
applyOptionsFromImage(CompileOpts, LinkOpts, Img, {Device}, Plugin);
// Should always come last!
Expand Down Expand Up @@ -590,9 +685,39 @@ sycl::detail::pi::PiProgram ProgramManager::getBuiltPIProgram(
!SYCLConfig<SYCL_DEVICELIB_NO_FALLBACK>::get())
DeviceLibReqMask = getDeviceLibReqMask(Img);

std::vector<sycl::detail::pi::PiProgram> ProgramsToLink;
// If we had a program in cache, then it should have been the fully linked
// program already.
if (!DeviceCodeWasInCache) {
for (RTDeviceBinaryImage *BinImg : DeviceImagesToLink) {
device_image_plain DevImagePlain =
getDeviceImageFromBinaryImage(BinImg, Context, Device);
const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
detail::getSyclObjImpl(DevImagePlain);

SerializedObj ImgSpecConsts =
DeviceImageImpl->get_spec_const_blob_ref();

auto [NativePrg, DeviceCodeWasInCache] = getOrCreatePIProgram(
*BinImg, Context, Device, CompileOpts + LinkOpts, ImgSpecConsts);
assert(!DeviceCodeWasInCache &&
"we don't expect dependencies to be already cached whilst the "
"main program is not cached");
std::ignore = DeviceCodeWasInCache;

if (BinImg->supportsSpecConstants())
setSpecializationConstants(DeviceImageImpl, NativePrg, Plugin);

ProgramsToLink.push_back(NativePrg);
}
}
ProgramPtr BuiltProgram =
build(std::move(ProgramManaged), ContextImpl, CompileOpts, LinkOpts,
getRawSyclObjImpl(Device)->getHandleRef(), DeviceLibReqMask);
getRawSyclObjImpl(Device)->getHandleRef(), DeviceLibReqMask,
ProgramsToLink);
// Those extra programs won't be used anymore, just the final linked result
for (sycl::detail::pi::PiProgram Prg : ProgramsToLink)
Plugin->call<PiApiKind::piProgramRelease>(Prg);

emitBuiltProgramInfo(BuiltProgram.get(), ContextImpl);

Expand All @@ -604,16 +729,21 @@ sycl::detail::pi::PiProgram ProgramManager::getBuiltPIProgram(
ContextImpl->addDeviceGlobalInitializer(BuiltProgram.get(), {Device}, &Img);

// Save program to persistent cache if it is not there
if (!DeviceCodeWasInCache)
if (!DeviceCodeWasInCache) {
PersistentDeviceCodeCache::putItemToDisc(
Device, Img, SpecConsts, CompileOpts + LinkOpts, BuiltProgram.get());
// Even though DeviceImagesToLink may contain device images with other
// kernels, we don't create extra on-disk cache entries for those (like we
// do for in-memory cache below) to avoid wasting disk space, because we
// expect the order of kernel execution within the app to be mostly stable
// between invocations.
}
return BuiltProgram.release();
};

uint32_t ImgId = Img.getImageID();
const sycl::detail::pi::PiDevice PiDevice = Dev->getHandleRef();
auto CacheKey =
std::make_pair(std::make_pair(std::move(SpecConsts), ImgId), PiDevice);
auto CacheKey = std::make_pair(std::make_pair(SpecConsts, ImgId), PiDevice);

auto GetCachedBuildF = [&Cache, &CacheKey]() {
return Cache.getOrInsertProgram(CacheKey);
Expand All @@ -626,12 +756,28 @@ sycl::detail::pi::PiProgram ProgramManager::getBuiltPIProgram(
// getOrBuild is not supposed to return nullptr
assert(BuildResult != nullptr && "Invalid build result");

sycl::detail::pi::PiProgram ResProgram = BuildResult->Val;
auto Plugin = ContextImpl->getPlugin();

// If we linked any extra device images for virtual functions, then we need to
// cache them as well.
for (const RTDeviceBinaryImage *BImg : DeviceImagesToLink) {
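// GetCachedBuildF captures CacheKey by reference, but the build result for
// the main image has already been obtained at this point, so we can simply
// update the image ID in place and re-use the key for each dependent image.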
CacheKey.first.second = BImg->getImageID();
bool DidInsert = Cache.insertBuiltProgram(CacheKey, ResProgram);
if (DidInsert) {
// For every cached copy of the program, we need to increment its refcount
Plugin->call<PiApiKind::piProgramRetain>(ResProgram);
}
}

// If caching is enabled, one copy of the program handle will be
// stored in the cache, and one handle is returned to the
// caller. In that case, we need to increase the ref count of the
// program.
ContextImpl->getPlugin()->call<PiApiKind::piProgramRetain>(BuildResult->Val);
return BuildResult->Val;
Plugin->call<PiApiKind::piProgramRetain>(ResProgram);
return ResProgram;
}

// When caching is enabled, the returned PiProgram and PiKernel will
Expand Down Expand Up @@ -1168,7 +1314,8 @@ getDeviceLibPrograms(const ContextImplPtr Context,
ProgramManager::ProgramPtr ProgramManager::build(
ProgramPtr Program, const ContextImplPtr Context,
const std::string &CompileOptions, const std::string &LinkOptions,
const sycl::detail::pi::PiDevice &Device, uint32_t DeviceLibReqMask) {
const sycl::detail::pi::PiDevice &Device, uint32_t DeviceLibReqMask,
const std::vector<sycl::detail::pi::PiProgram> &ExtraProgramsToLink) {

if (DbgProgMgr > 0) {
std::cerr << ">>> ProgramManager::build(" << Program.get() << ", "
Expand All @@ -1194,7 +1341,7 @@ ProgramManager::ProgramPtr ProgramManager::build(
static bool ForceLink = ForceLinkEnv && (*ForceLinkEnv == '1');

const PluginPtr &Plugin = Context->getPlugin();
if (LinkPrograms.empty() && !ForceLink) {
if (LinkPrograms.empty() && ExtraProgramsToLink.empty() && !ForceLink) {
const std::string &Options = LinkOptions.empty()
? CompileOptions
: (CompileOptions + " " + LinkOptions);
Expand All @@ -1216,6 +1363,13 @@ ProgramManager::ProgramPtr ProgramManager::build(
nullptr, nullptr, nullptr, nullptr);
LinkPrograms.push_back(Program.get());

for (sycl::detail::pi::PiProgram Prg : ExtraProgramsToLink) {
Plugin->call<PiApiKind::piProgramCompile>(
Prg, /*num devices =*/1, &Device, CompileOptions.c_str(), 0, nullptr,
nullptr, nullptr, nullptr);
LinkPrograms.push_back(Prg);
}

sycl::detail::pi::PiProgram LinkedProg = nullptr;
auto doLink = [&] {
return Plugin->call_nocheck<PiApiKind::piProgramLink>(
Expand Down Expand Up @@ -1291,6 +1445,13 @@ void ProgramManager::addImages(pi_device_binaries DeviceBinary) {
for (const pi_device_binary_property &ExportedSymbol : ExportedSymbols)
m_ExportedSymbols.insert(ExportedSymbol->Name);

// Record mapping between virtual function sets and device images
for (const pi_device_binary_property &VFProp : Img->getVirtualFunctions()) {
std::string StrValue = DeviceBinaryProperty(VFProp).asCString();
for (const auto &SetName : detail::split_string(StrValue, ','))
m_VFSet2BinImage[SetName].insert(Img.get());
}

if (DumpImages) {
const bool NeedsSequenceID = std::any_of(
m_BinImg2KernelIDs.begin(), m_BinImg2KernelIDs.end(),
Expand Down Expand Up @@ -1954,31 +2115,6 @@ std::vector<device_image_plain> ProgramManager::getSYCLDeviceImages(
return DeviceImages;
}

/// Transfers the specialization constants stored in \p InputImpl to \p Prog.
///
/// Enables the ITT annotation specialization constant if needed, then passes
/// every descriptor marked as set to the plugin, using the matching bytes of
/// the image's specialization constant blob. The image's spec-constant data
/// lock is held for the duration of the traversal.
static void
setSpecializationConstants(const std::shared_ptr<device_image_impl> &InputImpl,
                           sycl::detail::pi::PiProgram Prog,
                           const PluginPtr &Plugin) {
  enableITTAnnotationsIfNeeded(Prog, Plugin);

  std::lock_guard<std::mutex> Guard{InputImpl->get_spec_const_data_lock()};
  const auto &DescsByName = InputImpl->get_spec_const_data_ref();
  const SerializedObj &Blob = InputImpl->get_spec_const_blob_ref();

  // Only the descriptor lists matter; the constant names are not used.
  for (const auto &NameAndDescs : DescsByName) {
    for (const device_image_impl::SpecConstDescT &Desc : NameAndDescs.second) {
      // Descriptors without an assigned value are left alone.
      if (!Desc.IsSet)
        continue;
      Plugin->call<PiApiKind::piextProgramSetSpecializationConstant>(
          Prog, Desc.ID, Desc.Size, Blob.data() + Desc.BlobOffset);
    }
  }
}

device_image_plain
ProgramManager::compile(const device_image_plain &DeviceImage,
const std::vector<device> &Devs,
Expand Down Expand Up @@ -2220,9 +2356,12 @@ device_image_plain ProgramManager::build(const device_image_plain &DeviceImage,
!SYCLConfig<SYCL_DEVICELIB_NO_FALLBACK>::get())
DeviceLibReqMask = getDeviceLibReqMask(Img);

// TODO: Add support for using virtual functions with kernel bundles
std::vector<sycl::detail::pi::PiProgram> ExtraProgramsToLink;
ProgramPtr BuiltProgram =
build(std::move(ProgramManaged), ContextImpl, CompileOpts, LinkOpts,
getRawSyclObjImpl(Devs[0])->getHandleRef(), DeviceLibReqMask);
getRawSyclObjImpl(Devs[0])->getHandleRef(), DeviceLibReqMask,
ExtraProgramsToLink);

emitBuiltProgramInfo(BuiltProgram.get(), ContextImpl);

Expand Down
Loading

0 comments on commit c7d018f

Please sign in to comment.