Skip to content

Commit

Permalink
[SYCL] Filter out unneeded device images with lower state than reques…
Browse files Browse the repository at this point in the history
…ted (#8523)

When fetching device images compatible with non-input states, we can
ignore an image if another one with a higher state is available for all
the possible kernel-device pairs. This patch adds the logic for
filtering out such unnecessary images so that we can avoid JIT
compilation if both AOT and SPIRV images are present.
  • Loading branch information
sergey-semenov authored Mar 21, 2023
1 parent 3be2e42 commit 61e5101
Show file tree
Hide file tree
Showing 7 changed files with 354 additions and 51 deletions.
130 changes: 102 additions & 28 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1683,46 +1683,120 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
}
assert(BinImages.size() > 0 && "Expected to find at least one device image");

// Ignore images with incompatible state. Image is considered compatible
// with a target state if an image is already in the target state or can
// be brought to target state by compiling/linking/building.
//
// Example: an image in "executable" state is not compatible with
// "input" target state - there is no operation to convert the image it
// to "input" state. An image in "input" state is compatible with
// "executable" target state because it can be built to get into
// "executable" state.
for (auto It = BinImages.begin(); It != BinImages.end();) {
if (getBinImageState(*It) > TargetState)
It = BinImages.erase(It);
else
++It;
}

std::vector<device_image_plain> SYCLDeviceImages;
for (RTDeviceBinaryImage *BinImage : BinImages) {
const bundle_state ImgState = getBinImageState(BinImage);

// Ignore images with incompatible state. Image is considered compatible
// with a target state if an image is already in the target state or can
// be brought to target state by compiling/linking/building.
//
// Example: an image in "executable" state is not compatible with
// "input" target state - there is no operation to convert the image it
// to "input" state. An image in "input" state is compatible with
// "executable" target state because it can be built to get into
// "executable" state.
if (ImgState > TargetState)
continue;

for (const sycl::device &Dev : Devs) {
// If a non-input state is requested, we can filter out some compatible
// images and return only those with the highest compatible state for each
// device-kernel pair. This map tracks how many kernel-device pairs need each
// image, so that any unneeded ones are skipped.
// TODO this has no effect if the requested state is input, consider having
// a separate branch for that case to avoid unnecessary tracking work.
struct DeviceBinaryImageInfo {
std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
bundle_state State = bundle_state::input;
int RequirementCounter = 0;
};
std::unordered_map<RTDeviceBinaryImage *, DeviceBinaryImageInfo> ImageInfoMap;

for (const sycl::device &Dev : Devs) {
// Track the highest image state for each requested kernel.
using StateImagesPairT =
std::pair<bundle_state, std::vector<RTDeviceBinaryImage *>>;
using KernelImageMapT =
std::map<kernel_id, StateImagesPairT, LessByNameComp>;
KernelImageMapT KernelImageMap;
if (!KernelIDs.empty())
for (const kernel_id &KernelID : KernelIDs)
KernelImageMap.insert({KernelID, {}});

for (RTDeviceBinaryImage *BinImage : BinImages) {
if (!compatibleWithDevice(BinImage, Dev) ||
!doesDevSupportDeviceRequirements(Dev, *BinImage))
continue;

std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
// Collect kernel names for the image
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
KernelIDs = m_BinImg2KernelIDs[BinImage];
// If the image does not contain any non-service kernels we can skip it.
if (!KernelIDs || KernelIDs->empty())
continue;
auto InsertRes = ImageInfoMap.insert({BinImage, {}});
DeviceBinaryImageInfo &ImgInfo = InsertRes.first->second;
if (InsertRes.second) {
ImgInfo.State = getBinImageState(BinImage);
// Collect kernel names for the image
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
ImgInfo.KernelIDs = m_BinImg2KernelIDs[BinImage];
}
}
const bundle_state ImgState = ImgInfo.State;
const std::shared_ptr<std::vector<sycl::kernel_id>> &ImageKernelIDs =
ImgInfo.KernelIDs;
int &ImgRequirementCounter = ImgInfo.RequirementCounter;

DeviceImageImplPtr Impl = std::make_shared<detail::device_image_impl>(
BinImage, Ctx, Devs, ImgState, KernelIDs, /*PIProgram=*/nullptr);
// If the image does not contain any non-service kernels we can skip it.
if (!ImageKernelIDs || ImageKernelIDs->empty())
continue;

SYCLDeviceImages.push_back(
createSyclObjFromImpl<device_image_plain>(Impl));
break;
// Update tracked information.
for (kernel_id &KernelID : *ImageKernelIDs) {
StateImagesPairT *StateImagesPair;
// If only specific kernels are requested, ignore the rest.
if (!KernelIDs.empty()) {
auto It = KernelImageMap.find(KernelID);
if (It == KernelImageMap.end())
continue;
StateImagesPair = &It->second;
} else
StateImagesPair = &KernelImageMap[KernelID];

auto &[KernelImagesState, KernelImages] = *StateImagesPair;

if (KernelImages.empty()) {
KernelImagesState = ImgState;
KernelImages.push_back(BinImage);
++ImgRequirementCounter;
} else if (KernelImagesState < ImgState) {
for (RTDeviceBinaryImage *Img : KernelImages) {
auto It = ImageInfoMap.find(Img);
assert(It != ImageInfoMap.end());
assert(It->second.RequirementCounter > 0);
--(It->second.RequirementCounter);
}
KernelImages.clear();
KernelImages.push_back(BinImage);
KernelImagesState = ImgState;
++ImgRequirementCounter;
} else if (KernelImagesState == ImgState) {
KernelImages.push_back(BinImage);
++ImgRequirementCounter;
}
}
}
}

for (const auto &ImgInfoPair : ImageInfoMap) {
if (ImgInfoPair.second.RequirementCounter == 0)
continue;

DeviceImageImplPtr Impl = std::make_shared<detail::device_image_impl>(
ImgInfoPair.first, Ctx, Devs, ImgInfoPair.second.State,
ImgInfoPair.second.KernelIDs, /*PIProgram=*/nullptr);

SYCLDeviceImages.push_back(createSyclObjFromImpl<device_image_plain>(Impl));
}

return SYCLDeviceImages;
}

Expand Down
1 change: 1 addition & 0 deletions sycl/unittests/SYCL2020/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_sycl_unittest(SYCL2020Tests OBJECT
GetNativeOpenCL.cpp
SpecializationConstant.cpp
KernelBundle.cpp
KernelBundleStateFiltering.cpp
KernelID.cpp
HasExtension.cpp
IsCompatible.cpp
Expand Down
213 changes: 213 additions & 0 deletions sycl/unittests/SYCL2020/KernelBundleStateFiltering.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
//==---- KernelBundleStateFiltering.cpp --- Kernel bundle unit test --------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <detail/device_impl.hpp>
#include <detail/kernel_bundle_impl.hpp>
#include <sycl/sycl.hpp>

#include <helpers/MockKernelInfo.hpp>
#include <helpers/PiImage.hpp>
#include <helpers/PiMock.hpp>

#include <gtest/gtest.h>

#include <algorithm>
#include <set>
#include <vector>

class KernelA;
class KernelB;
class KernelC;
class KernelD;
class KernelE;
namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace detail {
template <> struct KernelInfo<KernelA> : public unittest::MockKernelInfoBase {
static constexpr const char *getName() { return "KernelA"; }
};
template <> struct KernelInfo<KernelB> : public unittest::MockKernelInfoBase {
static constexpr const char *getName() { return "KernelB"; }
};
template <> struct KernelInfo<KernelC> : public unittest::MockKernelInfoBase {
static constexpr const char *getName() { return "KernelC"; }
};
template <> struct KernelInfo<KernelD> : public unittest::MockKernelInfoBase {
static constexpr const char *getName() { return "KernelD"; }
};
template <> struct KernelInfo<KernelE> : public unittest::MockKernelInfoBase {
static constexpr const char *getName() { return "KernelE"; }
};
} // namespace detail
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl

namespace {

std::set<const void *> TrackedImages;
sycl::unittest::PiImage
generateDefaultImage(std::initializer_list<std::string> KernelNames,
pi_device_binary_type BinaryType,
const char *DeviceTargetSpec) {
using namespace sycl::unittest;

PiPropertySet PropSet;

static unsigned char NImage = 0;
std::vector<unsigned char> Bin{NImage++};

PiArray<PiOffloadEntry> Entries = makeEmptyKernels(KernelNames);

PiImage Img{BinaryType, // Format
DeviceTargetSpec,
"", // Compile options
"", // Link options
std::move(Bin),
std::move(Entries),
std::move(PropSet)};
const void *BinaryPtr = Img.getBinaryPtr();
TrackedImages.insert(BinaryPtr);

return Img;
}

// Image 0: input, KernelA KernelB
// Image 1: exe, KernelA
// Image 2: input, KernelC
// Image 3: exe, KernelC
// Image 4: input, KernelD
// Image 5: input, KernelE
// Image 6: exe, KernelE
// Image 7: exe. KernelE
sycl::unittest::PiImage Imgs[] = {
generateDefaultImage({"KernelA", "KernelB"}, PI_DEVICE_BINARY_TYPE_SPIRV,
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64),
generateDefaultImage({"KernelA"}, PI_DEVICE_BINARY_TYPE_NATIVE,
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64_X86_64),
generateDefaultImage({"KernelC"}, PI_DEVICE_BINARY_TYPE_SPIRV,
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64),
generateDefaultImage({"KernelC"}, PI_DEVICE_BINARY_TYPE_NATIVE,
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64_X86_64),
generateDefaultImage({"KernelD"}, PI_DEVICE_BINARY_TYPE_SPIRV,
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64),
generateDefaultImage({"KernelE"}, PI_DEVICE_BINARY_TYPE_SPIRV,
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64),
generateDefaultImage({"KernelE"}, PI_DEVICE_BINARY_TYPE_NATIVE,
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64_X86_64),
generateDefaultImage({"KernelE"}, PI_DEVICE_BINARY_TYPE_NATIVE,
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64_X86_64)};

sycl::unittest::PiImageArray<std::size(Imgs)> ImgArray{Imgs};
std::vector<unsigned char> UsedImageIndices;

void redefinedPiProgramCreateCommon(const void *bin) {
if (TrackedImages.count(bin) != 0) {
unsigned char ImgIdx = *reinterpret_cast<const unsigned char *>(bin);
UsedImageIndices.push_back(ImgIdx);
}
}

pi_result redefinedPiProgramCreate(pi_context context, const void *il,
size_t length, pi_program *res_program) {
redefinedPiProgramCreateCommon(il);
return PI_SUCCESS;
}

pi_result redefinedPiProgramCreateWithBinary(
pi_context context, pi_uint32 num_devices, const pi_device *device_list,
const size_t *lengths, const unsigned char **binaries,
size_t num_metadata_entries, const pi_device_binary_property *metadata,
pi_int32 *binary_status, pi_program *ret_program) {
redefinedPiProgramCreateCommon(binaries[0]);
return PI_SUCCESS;
}

pi_result redefinedDevicesGet(pi_platform platform, pi_device_type device_type,
pi_uint32 num_entries, pi_device *devices,
pi_uint32 *num_devices) {
if (num_devices)
*num_devices = 2;

if (devices) {
devices[0] = reinterpret_cast<pi_device>(1);
devices[1] = reinterpret_cast<pi_device>(2);
}

return PI_SUCCESS;
}

pi_result redefinedExtDeviceSelectBinary(pi_device device,
pi_device_binary *binaries,
pi_uint32 num_binaries,
pi_uint32 *selected_binary_ind) {
EXPECT_EQ(num_binaries, 1U);
// Treat image 3 as incompatible with one of the devices.
if (TrackedImages.count(binaries[0]->BinaryStart) != 0 &&
*binaries[0]->BinaryStart == 3 &&
device == reinterpret_cast<pi_device>(2)) {
return PI_ERROR_INVALID_BINARY;
}
*selected_binary_ind = 0;
return PI_SUCCESS;
}

void verifyImageUse(const std::vector<unsigned char> &ExpectedImages) {
std::sort(UsedImageIndices.begin(), UsedImageIndices.end());
EXPECT_TRUE(std::is_sorted(ExpectedImages.begin(), ExpectedImages.end()));
EXPECT_EQ(UsedImageIndices, ExpectedImages);
UsedImageIndices.clear();
}

TEST(KernelBundle, DeviceImageStateFiltering) {
sycl::unittest::PiMock Mock;
Mock.redefineAfter<sycl::detail::PiApiKind::piProgramCreate>(
redefinedPiProgramCreate);
Mock.redefineAfter<sycl::detail::PiApiKind::piProgramCreateWithBinary>(
redefinedPiProgramCreateWithBinary);

// No kernel ids specified.
{
const sycl::device Dev = Mock.getPlatform().get_devices()[0];
sycl::context Ctx{Dev};

sycl::kernel_bundle<sycl::bundle_state::executable> KernelBundle =
sycl::get_kernel_bundle<sycl::bundle_state::executable>(Ctx, {Dev});
verifyImageUse({0, 1, 3, 4, 6, 7});
}

sycl::kernel_id KernelAID = sycl::get_kernel_id<KernelA>();
sycl::kernel_id KernelCID = sycl::get_kernel_id<KernelC>();
sycl::kernel_id KernelDID = sycl::get_kernel_id<KernelD>();

// Request specific kernel ids.
{
const sycl::device Dev = Mock.getPlatform().get_devices()[0];
sycl::context Ctx{Dev};

sycl::kernel_bundle<sycl::bundle_state::executable> KernelBundle =
sycl::get_kernel_bundle<sycl::bundle_state::executable>(
Ctx, {Dev}, {KernelAID, KernelCID, KernelDID});
verifyImageUse({1, 3, 4});
}

// Check the case where some executable images are unsupported by one of
// the devices.
{
Mock.redefine<sycl::detail::PiApiKind::piDevicesGet>(redefinedDevicesGet);
Mock.redefine<sycl::detail::PiApiKind::piextDeviceSelectBinary>(
redefinedExtDeviceSelectBinary);
const std::vector<sycl::device> Devs = Mock.getPlatform().get_devices();
sycl::context Ctx{Devs};

sycl::kernel_bundle<sycl::bundle_state::executable> KernelBundle =
sycl::get_kernel_bundle<sycl::bundle_state::executable>(
Ctx, Devs, {KernelAID, KernelCID, KernelDID});
verifyImageUse({1, 2, 3, 4});
}
}
} // namespace
28 changes: 28 additions & 0 deletions sycl/unittests/helpers/MockKernelInfo.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <sycl/detail/kernel_desc.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace unittest {
struct MockKernelInfoBase {
static constexpr unsigned getNumParams() { return 0; }
static const detail::kernel_param_desc_t &getParamDesc(int) {
static detail::kernel_param_desc_t Dummy;
return Dummy;
}
static constexpr bool isESIMD() { return false; }
static constexpr bool callsThisItem() { return false; }
static constexpr bool callsAnyThisFreeFunction() { return false; }
static constexpr int64_t getKernelSize() { return 1; }
};

} // namespace unittest
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
Loading

0 comments on commit 61e5101

Please sign in to comment.