[cuDNN] Add initial implementation of cuDNN custom module #123

Merged (4 commits) on Apr 4, 2023
5 changes: 5 additions & 0 deletions openxla-nvgpu/.gitignore
@@ -1 +1,6 @@
build/

# Source indexing files
compile_commands.json
.cache/clangd
.clangd/
19 changes: 19 additions & 0 deletions openxla-nvgpu/CMakeLists.txt
@@ -6,6 +6,8 @@

cmake_minimum_required(VERSION 3.21...3.24)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

project(OPENXLA_NVGPU)

# TODO: Fix this once the project is slotted into place.
@@ -22,5 +24,22 @@ option(IREE_TARGET_BACKEND_DEFAULTS "Disables target backend" OFF)
option(IREE_TARGET_BACKEND_CUDA "Enables CUDA target backend" ON)
option(IREE_COMPILER_BUILD_SHARED_LIBS "Enables shared libraries in the compiler by default" ON)

# TODO: The `llvm-cpu` compiler backend and the `local-sync` HAL driver are
# enabled only for running tests. Disable them once we use a proper CUDA
# target and driver.
option(IREE_HAL_DRIVER_LOCAL_SYNC "Enables the 'local-sync' runtime HAL driver" ON)
option(IREE_TARGET_BACKEND_LLVM_CPU "Enables the 'llvm-cpu' target backend" ON)

set(IREE_COMPILER_PLUGIN_PATHS "${CMAKE_CURRENT_SOURCE_DIR}" CACHE STRING "OpenXLA nvgpu plugins")
add_subdirectory("${IREE_ROOT_DIR}" "iree_core")

#-------------------------------------------------------------------------------
# OpenXLA NVGPU Runtime.
#
# Integration of NVIDIA libraries with IREE runtime via custom VM modules.
#-------------------------------------------------------------------------------

# TODO: Use same compiler flags for building runtime targets as the IREE core.
set(IREE_CXX_STANDARD 17)

add_subdirectory(runtime)
add_subdirectory(tools)
7 changes: 7 additions & 0 deletions openxla-nvgpu/runtime/CMakeLists.txt
@@ -0,0 +1,7 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

add_subdirectory(src)
11 changes: 11 additions & 0 deletions openxla-nvgpu/runtime/src/CMakeLists.txt
@@ -0,0 +1,11 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

set(IREE_PACKAGE_ROOT_DIR "${CMAKE_CURRENT_LIST_DIR}")
set(IREE_PACKAGE_ROOT_PREFIX "")
set(IREE_COMPILER_TABLEGEN_INCLUDE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}")

add_subdirectory(openxla/runtime/nvgpu)
29 changes: 29 additions & 0 deletions openxla-nvgpu/runtime/src/openxla/runtime/nvgpu/CMakeLists.txt
@@ -0,0 +1,29 @@
# Copyright 2023 The IREE Authors
[Collaborator review comment: It would be better to drive this from bazel_to_cmake (which will also aid downstream integration). Fine to come back and do that as a follow-up if you're not sure it works (I had to make a number of tweaks to it).]

#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

add_subdirectory(test)

iree_cc_library(
NAME
defs
INCLUDES
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../..>"
"$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/../../..>"
PUBLIC
)

iree_cc_library(
NAME
cudnn_module
HDRS
"cudnn_module.h"
SRCS
"cudnn_module.cpp"
DEPS
::defs
iree::runtime
PUBLIC
)
67 changes: 67 additions & 0 deletions openxla-nvgpu/runtime/src/openxla/runtime/nvgpu/cudnn_module.cpp
@@ -0,0 +1,67 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "openxla/runtime/nvgpu/cudnn_module.h"

#include <iree/base/status_cc.h>

#include <cstdio>

#include "iree/vm/native_module_cc.h"

namespace openxla::runtime::nvgpu {

using namespace iree;

//===----------------------------------------------------------------------===//
// CuDNN module state encapsulates all the state required for running cuDNN
// operations (launching cuDNN graphs on a stream) at run time.
//===----------------------------------------------------------------------===//

class CuDNNModuleState {
public:
Status Hello() {
fprintf(stderr, "Hello from OpenXLA CuDNN Module!\n");
return OkStatus();
}
};

static const vm::NativeFunction<CuDNNModuleState> kCuDNNModuleFunctions[] = {
vm::MakeNativeFunction("hello", &CuDNNModuleState::Hello),
};

//===----------------------------------------------------------------------===//
// CuDNN module instance that will be allocated and reused across contexts.
//===----------------------------------------------------------------------===//

class CuDNNModule final : public vm::NativeModule<CuDNNModuleState> {
public:
using vm::NativeModule<CuDNNModuleState>::NativeModule;

StatusOr<std::unique_ptr<CuDNNModuleState>> CreateState(
iree_allocator_t host_allocator) override {
return std::make_unique<CuDNNModuleState>();
}
};

} // namespace openxla::runtime::nvgpu

//===----------------------------------------------------------------------===//
// Register cuDNN module with IREE runtime.
//===----------------------------------------------------------------------===//

using namespace openxla::runtime::nvgpu;

extern "C" iree_status_t iree_custom_module_cudnn_create(
iree_vm_instance_t* instance, iree_hal_device_t* device,
iree_allocator_t host_allocator, iree_vm_module_t** out_module) {
IREE_ASSERT_ARGUMENT(out_module);
auto module = std::make_unique<CuDNNModule>(
"cudnn", /*version=*/0, instance, host_allocator,
span<const vm::NativeFunction<CuDNNModuleState>>(kCuDNNModuleFunctions));
*out_module = module.release()->interface();
return iree_ok_status();
}
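For orientation, growing the module beyond `hello` follows the same pattern: add a method to CuDNNModuleState and register it in kCuDNNModuleFunctions. The fragment below is an illustrative sketch only, not part of this change; the `bye` export is hypothetical, and it assumes the same includes, namespace, and iree/vm/native_module_cc.h wrappers used in cudnn_module.cpp above.

// Hypothetical sketch (not in this PR): exporting a second zero-argument
// function alongside `hello`, reusing the NativeFunction machinery above.
class CuDNNModuleState {
 public:
  Status Hello() {
    fprintf(stderr, "Hello from OpenXLA CuDNN Module!\n");
    return OkStatus();
  }

  // Exported below as `cudnn.bye`.
  Status Bye() {
    fprintf(stderr, "Bye from OpenXLA CuDNN Module!\n");
    return OkStatus();
  }
};

static const vm::NativeFunction<CuDNNModuleState> kCuDNNModuleFunctions[] = {
    vm::MakeNativeFunction("hello", &CuDNNModuleState::Hello),
    vm::MakeNativeFunction("bye", &CuDNNModuleState::Bye),
};

On the compiler side a caller would then declare the import as `func.func private @cudnn.bye()` next to the existing `@cudnn.hello` declaration in example.mlir.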
27 changes: 27 additions & 0 deletions openxla-nvgpu/runtime/src/openxla/runtime/nvgpu/cudnn_module.h
@@ -0,0 +1,27 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef OPENXLA_RUNTIME_NVGPU_CUDNN_MODULE_H_
#define OPENXLA_RUNTIME_NVGPU_CUDNN_MODULE_H_

#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/vm/api.h"

#ifdef __cplusplus
extern "C" {
#endif // __cplusplus

iree_status_t iree_custom_module_cudnn_create(iree_vm_instance_t* instance,
iree_hal_device_t* device,
iree_allocator_t host_allocator,
iree_vm_module_t** out_module);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

#endif // OPENXLA_RUNTIME_NVGPU_CUDNN_MODULE_H_
18 changes: 18 additions & 0 deletions openxla-nvgpu/runtime/src/openxla/runtime/nvgpu/test/CMakeLists.txt
@@ -0,0 +1,18 @@
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

iree_lit_test_suite(
NAME
lit
SRCS
"example.mlir"
TOOLS
FileCheck
iree-compile
openxla-runner
LABELS
"hostonly"
)
13 changes: 13 additions & 0 deletions openxla-nvgpu/runtime/src/openxla/runtime/nvgpu/test/example.mlir
@@ -0,0 +1,13 @@
// RUN: iree-compile %s --iree-hal-target-backends=llvm-cpu | openxla-runner - example.main | FileCheck %s

module @example {

func.func private @cudnn.hello()

func.func @main() {
// CHECK: Hello from OpenXLA CuDNN Module!
call @cudnn.hello() : () -> ()
return
}

}
25 changes: 25 additions & 0 deletions openxla-nvgpu/tools/CMakeLists.txt
@@ -0,0 +1,25 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

iree_cc_binary(
NAME
openxla-runner
SRCS
"openxla-runner.c"
DEPS
iree::base
iree::base::internal::flags
iree::base::tracing
iree::hal
iree::modules::hal::types
iree::tooling::comparison
iree::tooling::context_util
iree::tooling::device_util
iree::tooling::instrument_util
iree::tooling::vm_util
iree::vm
openxla::runtime::nvgpu::cudnn_module
)
97 changes: 97 additions & 0 deletions openxla-nvgpu/tools/openxla-runner.c
@@ -0,0 +1,97 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <stdio.h>

#include "iree/modules/hal/types.h"
#include "iree/runtime/api.h"
#include "openxla/runtime/nvgpu/cudnn_module.h"

// TODO: This is a temporary workaround for missing custom module integration
// in IREE tools (iree-run-module). We already have flags to enable plugins in
// compiler tools (`iree-compile` and `iree-opt`), but not yet in "runtime"
// tools. This tool can only run VM functions with empty arguments and empty
// results, and is intended for testing the cuDNN custom module.
int main(int argc, char** argv) {
if (argc != 3) {
fprintf(stderr,
"Usage:\n"
" openxla-runner - <entry.point> # read from stdin\n"
" openxla-runner </path/to/say_hello.vmfb> "
"<entry.point>\n");
return -1;
}

// Internally IREE does not (in general) use malloc and instead uses the
// provided allocator to allocate and free memory. Applications can integrate
// their own allocator as-needed.
iree_allocator_t host_allocator = iree_allocator_system();

// Create and configure the instance shared across all sessions.
iree_runtime_instance_options_t instance_options;
iree_runtime_instance_options_initialize(&instance_options);
iree_runtime_instance_options_use_all_available_drivers(&instance_options);
iree_runtime_instance_t* instance = NULL;
IREE_CHECK_OK(iree_runtime_instance_create(&instance_options, host_allocator,
&instance));

// Try to create the device - it should always succeed as it's a CPU device.
iree_hal_device_t* device = NULL;
IREE_CHECK_OK(iree_runtime_instance_try_create_default_device(
instance, iree_make_cstring_view("local-sync"), &device));

// Create one session per loaded module to hold the module state.
iree_runtime_session_options_t session_options;
iree_runtime_session_options_initialize(&session_options);
iree_runtime_session_t* session = NULL;
IREE_CHECK_OK(iree_runtime_session_create_with_device(
instance, &session_options, device,
iree_runtime_instance_host_allocator(instance), &session));

// Create the custom module that can be reused across contexts.
iree_vm_module_t* custom_module = NULL;
IREE_CHECK_OK(iree_custom_module_cudnn_create(
iree_runtime_instance_vm_instance(instance), device, host_allocator,
&custom_module));
IREE_CHECK_OK(iree_runtime_session_append_module(session, custom_module));
iree_vm_module_release(custom_module);

// Load the module from stdin or a file on disk.
const char* module_path = argv[1];
if (strcmp(module_path, "-") == 0) {
IREE_CHECK_OK(
iree_runtime_session_append_bytecode_module_from_stdin(session));
} else {
IREE_CHECK_OK(iree_runtime_session_append_bytecode_module_from_file(
session, module_path));
}

iree_string_view_t entry_point = iree_make_cstring_view(argv[2]);
fprintf(stdout, "INVOKE BEGIN %.*s\n", (int)entry_point.size,
entry_point.data);
fflush(stdout);

iree_vm_list_t* inputs = NULL;
IREE_CHECK_OK(iree_vm_list_create(NULL, 1, host_allocator, &inputs));
iree_vm_list_t* outputs = NULL;
IREE_CHECK_OK(iree_vm_list_create(NULL, 1, host_allocator, &outputs));

// Synchronously invoke the requested function.
IREE_CHECK_OK(
iree_runtime_session_call_by_name(session, entry_point, inputs, outputs));

fprintf(stdout, "INVOKE END %.*s\n", (int)entry_point.size, entry_point.data);
fflush(stdout);

iree_vm_list_release(inputs);
iree_vm_list_release(outputs);

iree_runtime_session_release(session);
iree_hal_device_release(device);
iree_runtime_instance_release(instance);

return 0;
}
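For reference, given the fprintf calls above and the CHECK line in example.mlir, a successful run of the lit test's RUN command should print "INVOKE BEGIN example.main" and "INVOKE END example.main" on stdout, with "Hello from OpenXLA CuDNN Module!" written to stderr by the module between the two.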