Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] PoC implementation of kernel compiler extension with libtooling and sycl-jit #15701

Open
wants to merge 14 commits into
base: sycl
Choose a base branch
from
Open
5 changes: 5 additions & 0 deletions sycl-jit/common/include/Kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,11 @@ struct SYCLKernelInfo {
: Name{KernelName}, Args{NumArgs}, Attributes{}, NDR{}, BinaryInfo{} {}
};

struct InMemoryFile {
const char *Path;
const char *Contents;
};

} // namespace jit_compiler

#endif // SYCL_FUSION_COMMON_KERNEL_H
11 changes: 11 additions & 0 deletions sycl-jit/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_llvm_library(sycl-jit
lib/fusion/FusionHelper.cpp
lib/fusion/JITContext.cpp
lib/fusion/ModuleHelper.cpp
lib/rtc/DeviceCompilation.cpp
lib/helper/ConfigHelper.cpp

SHARED
Expand All @@ -29,6 +30,14 @@ add_llvm_library(sycl-jit
TargetParser
MC
${LLVM_TARGETS_TO_BUILD}

LINK_LIBS
clangBasic
clangDriver
clangFrontend
clangCodeGen
clangTooling
clangSerialization
)

target_compile_options(sycl-jit PRIVATE ${SYCL_JIT_WARNING_FLAGS})
Expand All @@ -40,6 +49,8 @@ target_include_directories(sycl-jit
SYSTEM PRIVATE
${LLVM_MAIN_INCLUDE_DIR}
${LLVM_SPIRV_INCLUDE_DIRS}
${LLVM_EXTERNAL_CLANG_SOURCE_DIR}/include
${CMAKE_BINARY_DIR}/tools/clang/include
)
target_include_directories(sycl-jit
PUBLIC
Expand Down
3 changes: 3 additions & 0 deletions sycl-jit/jit-compiler/include/KernelFusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ JITResult materializeSpecConstants(const char *KernelName,
jit_compiler::SYCLKernelBinaryInfo &BinInfo,
View<unsigned char> SpecConstBlob);

JITResult compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs);

/// Clear all previously set options.
void resetJITConfiguration();

Expand Down
1 change: 1 addition & 0 deletions sycl-jit/jit-compiler/ld-version-script.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
/* Export the library entry points */
fuseKernels;
materializeSpecConstants;
compileSYCL;
resetJITConfiguration;
addToJITConfiguration;

Expand Down
26 changes: 26 additions & 0 deletions sycl-jit/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "fusion/FusionPipeline.h"
#include "helper/ConfigHelper.h"
#include "helper/ErrorHandling.h"
#include "rtc/DeviceCompilation.h"
#include "translation/KernelTranslation.h"
#include "translation/SPIRVLLVMTranslation.h"
#include <llvm/Support/Error.h>
Expand Down Expand Up @@ -235,6 +236,31 @@ extern "C" JITResult fuseKernels(View<SYCLKernelInfo> KernelInformation,
return JITResult{FusedKernelInfo};
}

extern "C" JITResult compileSYCL(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs) {
auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgs);
if (!ModuleOrErr) {
return errorToFusionResult(ModuleOrErr.takeError(),
"Device compilation failed");
}
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);

SYCLKernelInfo Kernel;
auto Error = translation::KernelTranslator::translateKernel(
Kernel, *Module, JITContext::getInstance(), BinaryFormat::SPIRV);

auto *LLVMCtx = &Module->getContext();
Module.reset();
delete LLVMCtx;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How expensive is it to set up and destroy the LLVMContext on every call to RTC? Would it be an alternative to use the context from JITContext?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not mistaken, certain things like constants and metadata are stored within LLVMContext and won't be deallocated even if a module which references them is deallocated. Therefore, keeping LLVMContext between RTC call invocations could result in some memory build-up.

At least that is the behavior we discovered a few years ago when we were debugging exceptionally huge memory footprint of sycl-post-link where we had a huge codebase compiled with per-kernel device code split. I don't know if anything has changed since that in the upstream LLVM, but we hadn't proposed any patches back then.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's certainly possible to pass an existing context into the ToolAction, but that also raises questions of thread safety.

For the performance implications, a) yes, looks like setting up a context does involve a non-trivial amount of work, and b) still seems true that types, constants and metadata are allocated in the context and not freed when the module is destroyed. I'd propose to keep the simple implementation for now, and will look out for the context setup overhead once we start benchmarking the RTC.


if (Error) {
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
}

return JITResult{Kernel};
}

extern "C" void resetJITConfiguration() { ConfigHelper::reset(); }

extern "C" void addToJITConfiguration(OptionStorage &&Opt) {
Expand Down
149 changes: 149 additions & 0 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
//==---------------------- DeviceCompilation.cpp ---------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "DeviceCompilation.h"

#include <clang/Basic/Version.h>
#include <clang/CodeGen/CodeGenAction.h>
#include <clang/Driver/Compilation.h>
#include <clang/Frontend/CompilerInstance.h>
#include <clang/Tooling/CompilationDatabase.h>
#include <clang/Tooling/Tooling.h>

#include <llvm/IR/Module.h>

#ifdef _GNU_SOURCE
#include <dlfcn.h>
static char X; // Dummy symbol, used as an anchor for `dlinfo` below.
#endif // _GNU_SOURCE

static constexpr auto InvalidDPCPPRoot = "<invalid>";
static constexpr auto JITLibraryPathSuffix = "/lib/libsycl-jit.so";

static const std::string &getDPCPPRoot() {
thread_local std::string DPCPPRoot;

if (!DPCPPRoot.empty()) {
return DPCPPRoot;
}
DPCPPRoot = InvalidDPCPPRoot;

#ifdef _GNU_SOURCE
Dl_info Info;
if (dladdr(&X, &Info)) {
std::string LoadedLibraryPath = Info.dli_fname;
auto Pos = LoadedLibraryPath.rfind(JITLibraryPathSuffix);
if (Pos != std::string::npos) {
DPCPPRoot = LoadedLibraryPath.substr(0, Pos);
}
}
#endif // _GNU_SOURCE

// TODO: Implemenent other means of determining the DPCPP root, e.g.
// evaluating the `CMPLR_ROOT` env.

return DPCPPRoot;
}

namespace {
using namespace clang;
using namespace clang::tooling;
using namespace clang::driver;

struct GetLLVMModuleAction : public ToolAction {
// Code adapted from `FrontendActionFactory::runInvocation`.
bool runInvocation(std::shared_ptr<CompilerInvocation> Invocation,
FileManager *Files,
std::shared_ptr<PCHContainerOperations> PCHContainerOps,
DiagnosticConsumer *DiagConsumer) override {
assert(!Module && "Action should only be invoked on a single file");

// Create a compiler instance to handle the actual work.
CompilerInstance Compiler(std::move(PCHContainerOps));
Compiler.setInvocation(std::move(Invocation));
Compiler.setFileManager(Files);

// Create the compiler's actual diagnostics engine.
Compiler.createDiagnostics(DiagConsumer, /*ShouldOwnClient=*/false);
if (!Compiler.hasDiagnostics()) {
return false;
}

Compiler.createSourceManager(*Files);

// Ignore `Compiler.getFrontendOpts().ProgramAction` (would be `EmitBC`) and
// create/execute an `EmitLLVMOnlyAction` (= codegen to LLVM module without
// emitting anything) instead.
EmitLLVMOnlyAction ELOA;
const bool Success = Compiler.ExecuteAction(ELOA);
Files->clearStatCache();
if (!Success) {
return false;
}

// Take the module and its context to extend the objects' lifetime.
Module = ELOA.takeModule();
ELOA.takeLLVMContext();

return true;
}

std::unique_ptr<llvm::Module> Module;
};

} // anonymous namespace

llvm::Expected<std::unique_ptr<llvm::Module>>
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs) {
const std::string &DPCPPRoot = getDPCPPRoot();
if (DPCPPRoot == InvalidDPCPPRoot) {
return llvm::createStringError("Could not locate DPCPP root directory");
}

SmallVector<std::string> CommandLine = {"-fsycl-device-only"};
// TODO: Allow instrumentation again when device library linking is
// implemented.
CommandLine.push_back("-fno-sycl-instrument-device-code");
CommandLine.append(UserArgs.begin(), UserArgs.end());
clang::tooling::FixedCompilationDatabase DB{".", CommandLine};

clang::tooling::ClangTool Tool{DB, {SourceFile.Path}};

// Set up in-memory filesystem.
Tool.mapVirtualFile(SourceFile.Path, SourceFile.Contents);
for (const auto &IF : IncludeFiles) {
Tool.mapVirtualFile(IF.Path, IF.Contents);
}

// Reset argument adjusters to drop the `-fsyntax-only` flag which is added by
// default by this API.
Tool.clearArgumentsAdjusters();
// Then, modify argv[0] and set the resource directory so that the driver
// picks up the correct SYCL environment.
Tool.appendArgumentsAdjuster(
[&DPCPPRoot](const CommandLineArguments &Args,
StringRef Filename) -> CommandLineArguments {
(void)Filename;
CommandLineArguments NewArgs = Args;
NewArgs[0] = (Twine(DPCPPRoot) + "/bin/clang++").str();
NewArgs.push_back((Twine("-resource-dir=") + DPCPPRoot + "/lib/clang/" +
Twine(CLANG_VERSION_MAJOR))
.str());
Comment on lines +136 to +138
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these paths also apply in a packaged release?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, a release icpx follows the same path structure (checked with -print-resource-dir).

return NewArgs;
});

GetLLVMModuleAction Action;
if (!Tool.run(&Action)) {
return std::move(Action.Module);
}

// TODO: Capture compiler errors from the ClangTool.
return llvm::createStringError("Unable to obtain LLVM module");
}
31 changes: 31 additions & 0 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//==---- DeviceCompilation.h - Compile SYCL device code with libtooling ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef SYCL_JIT_COMPILER_RTC_DEVICE_COMPILATION_H
#define SYCL_JIT_COMPILER_RTC_DEVICE_COMPILATION_H

#include "Kernel.h"
#include "View.h"

#include <llvm/Support/Error.h>

#include <memory>

namespace llvm {
class Module;
} // namespace llvm

namespace jit_compiler {

llvm::Expected<std::unique_ptr<llvm::Module>>
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs);

} // namespace jit_compiler

#endif // SYCL_JIT_COMPILER_RTC_DEVICE_COMPILATION_H
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ SPIRV::TranslatorOpts &SPIRVLLVMTranslator::translatorOpts() {
// there's currently no obvious way to iterate the
// array of extensions in KernelInfo.
TransOpt.enableAllExtensions();
// TODO: Remove this workaround.
TransOpt.setAllowedToUseExtension(
SPIRV::ExtensionID::SPV_KHR_untyped_pointers, false);
TransOpt.setDesiredBIsRepresentation(
SPIRV::BIsRepresentation::SPIRVFriendlyIR);
// TODO: We need to take care of specialization constants, either by
Expand Down
8 changes: 7 additions & 1 deletion sycl/include/sycl/kernel_bundle_enums.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ enum class bundle_state : char {

namespace ext::oneapi::experimental {

enum class source_language : int { opencl = 0, spirv = 1, sycl = 2 /* cuda */ };
enum class source_language : int {
opencl = 0,
spirv = 1,
sycl = 2,
/* cuda */
sycl_jit /* temporary, alternative implementation for SYCL */
};

// opencl versions
struct cl_version {
Expand Down
54 changes: 54 additions & 0 deletions sycl/source/detail/jit_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ jit_compiler::jit_compiler() {
return false;
}

this->CompileSYCLHandle = reinterpret_cast<CompileSYCLFuncT>(
sycl::detail::ur::getOsLibraryFuncAddress(LibraryPtr, "compileSYCL"));
if (!this->CompileSYCLHandle) {
printPerformanceWarning(
"Cannot resolve JIT library function entry point");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds more serious than a mere performance warning :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

printPerformanceWarning is the generic error message helper in sycl-jit, but yes, I agree it's a bit of a misnomer when used here (and while attempting to set-up the other entrypoints before).

return false;
}

return true;
};
Available = checkJITLibrary();
Expand Down Expand Up @@ -1143,6 +1151,52 @@ std::vector<uint8_t> jit_compiler::encodeReqdWorkGroupSize(
return Encoded;
}

std::vector<uint8_t> jit_compiler::compileSYCL(
const std::string &Id, const std::string &SYCLSource,
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
const std::vector<std::string> &UserArgs, std::string *LogPtr,
const std::vector<std::string> &RegisteredKernelNames) {

// TODO: Handle template instantiation.
if (!RegisteredKernelNames.empty()) {
throw sycl::exception(
sycl::errc::build,
"Property `sycl::ext::oneapi::experimental::registered_kernel_names` "
"is not yet supported for the `sycl_jit` source language");
}

std::string SYCLFileName = Id + ".cpp";
::jit_compiler::InMemoryFile SourceFile{SYCLFileName.c_str(),
SYCLSource.c_str()};

std::vector<::jit_compiler::InMemoryFile> IncludeFilesView;
IncludeFilesView.reserve(IncludePairs.size());
std::transform(IncludePairs.begin(), IncludePairs.end(),
std::back_inserter(IncludeFilesView), [](const auto &Pair) {
return ::jit_compiler::InMemoryFile{Pair.first.c_str(),
Pair.second.c_str()};
});
std::vector<const char *> UserArgsView;
UserArgsView.reserve(UserArgs.size());
std::transform(UserArgs.begin(), UserArgs.end(),
std::back_inserter(UserArgsView),
[](const auto &Arg) { return Arg.c_str(); });

auto Result = CompileSYCLHandle(SourceFile, IncludeFilesView, UserArgsView);

if (Result.failed()) {
throw sycl::exception(sycl::errc::build, Result.getErrorMessage());
}

// TODO: We currently don't have a meaningful build log.
(void)LogPtr;

const auto &BI = Result.getKernelInfo().BinaryInfo;
assert(BI.Format == ::jit_compiler::BinaryFormat::SPIRV);
std::vector<uint8_t> SPV(BI.BinaryStart, BI.BinaryStart + BI.BinarySize);
return SPV;
}

} // namespace detail
} // namespace _V1
} // namespace sycl
Expand Down
Loading
Loading