-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[cuDNN] Add initial implementation of cuDNN custom module #123
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,6 @@ | ||
build/ | ||
|
||
# Source indexing files | ||
compile_commands.json | ||
.cache/clangd | ||
.clangd/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
add_subdirectory(src) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
set(IREE_PACKAGE_ROOT_DIR "${CMAKE_CURRENT_LIST_DIR}") | ||
set(IREE_PACKAGE_ROOT_PREFIX "") | ||
set(IREE_COMPILER_TABLEGEN_INCLUDE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}") | ||
|
||
add_subdirectory(openxla/runtime/nvgpu) |
29 changes: 29 additions & 0 deletions
29
openxla-nvgpu/runtime/src/openxla/runtime/nvgpu/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
add_subdirectory(test) | ||
|
||
iree_cc_library( | ||
NAME | ||
defs | ||
INCLUDES | ||
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../..>" | ||
"$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/../../..>" | ||
PUBLIC | ||
) | ||
|
||
iree_cc_library( | ||
NAME | ||
cudnn_module | ||
HDRS | ||
"cudnn_module.h" | ||
SRCS | ||
"cudnn_module.cpp" | ||
DEPS | ||
::defs | ||
iree::runtime | ||
PUBLIC | ||
) |
67 changes: 67 additions & 0 deletions
67
openxla-nvgpu/runtime/src/openxla/runtime/nvgpu/cudnn_module.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "openxla/runtime/nvgpu/cudnn_module.h" | ||
|
||
#include <iree/base/status_cc.h> | ||
|
||
#include <cstdio> | ||
|
||
#include "iree/vm/native_module_cc.h" | ||
|
||
namespace openxla::runtime::nvgpu { | ||
|
||
using namespace iree; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// CuDNN module state encapsulates all the state required for running cuDNN | ||
// operations (launching cuDNN graphs on a stream) at run time. | ||
//===----------------------------------------------------------------------===// | ||
|
||
class CuDNNModuleState { | ||
public: | ||
Status Hello() { | ||
fprintf(stderr, "Hello from OpenXLA CuDNN Module!\n"); | ||
return OkStatus(); | ||
} | ||
}; | ||
|
||
static const vm::NativeFunction<CuDNNModuleState> kCuDNNModuleFunctions[] = { | ||
vm::MakeNativeFunction("hello", &CuDNNModuleState::Hello), | ||
}; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// CuDNN module instance that will be allocated and reused across contexts. | ||
//===----------------------------------------------------------------------===// | ||
|
||
class CuDNNModule final : public vm::NativeModule<CuDNNModuleState> { | ||
public: | ||
using vm::NativeModule<CuDNNModuleState>::NativeModule; | ||
|
||
StatusOr<std::unique_ptr<CuDNNModuleState>> CreateState( | ||
iree_allocator_t host_allocator) override { | ||
return std::make_unique<CuDNNModuleState>(); | ||
} | ||
}; | ||
|
||
} // namespace openxla::runtime::nvgpu | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Register cuDNN module with IREE runtime. | ||
//===----------------------------------------------------------------------===// | ||
|
||
using namespace openxla::runtime::nvgpu; | ||
|
||
extern "C" iree_status_t iree_custom_module_cudnn_create( | ||
iree_vm_instance_t* instance, iree_hal_device_t* device, | ||
iree_allocator_t host_allocator, iree_vm_module_t** out_module) { | ||
IREE_ASSERT_ARGUMENT(out_module); | ||
auto module = std::make_unique<CuDNNModule>( | ||
"cudnn", /*version=*/0, instance, host_allocator, | ||
span<const vm::NativeFunction<CuDNNModuleState>>(kCuDNNModuleFunctions)); | ||
*out_module = module.release()->interface(); | ||
return iree_ok_status(); | ||
} |
27 changes: 27 additions & 0 deletions
27
openxla-nvgpu/runtime/src/openxla/runtime/nvgpu/cudnn_module.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#ifndef OPENXLA_RUNTIME_NVGPU_CUDNN_MODULE_H_ | ||
#define OPENXLA_RUNTIME_NVGPU_CUDNN_MODULE_H_ | ||
|
||
#include "iree/base/api.h" | ||
#include "iree/hal/api.h" | ||
#include "iree/vm/api.h" | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif // __cplusplus | ||
|
||
iree_status_t iree_custom_module_cudnn_create(iree_vm_instance_t* instance, | ||
iree_hal_device_t* device, | ||
iree_allocator_t host_allocator, | ||
iree_vm_module_t** out_module); | ||
|
||
#ifdef __cplusplus | ||
} // extern "C" | ||
#endif // __cplusplus | ||
|
||
#endif // OPENXLA_RUNTIME_NVGPU_CUDNN_MODULE_H_ |
18 changes: 18 additions & 0 deletions
18
openxla-nvgpu/runtime/src/openxla/runtime/nvgpu/test/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright 2022 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
iree_lit_test_suite( | ||
NAME | ||
lit | ||
SRCS | ||
"example.mlir" | ||
TOOLS | ||
FileCheck | ||
iree-compile | ||
openxla-runner | ||
LABELS | ||
"hostonly" | ||
) |
13 changes: 13 additions & 0 deletions
13
openxla-nvgpu/runtime/src/openxla/runtime/nvgpu/test/example.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// RUN: iree-compile %s --iree-hal-target-backends=llvm-cpu | openxla-runner - example.main | FileCheck %s | ||
|
||
module @example { | ||
|
||
func.func private @cudnn.hello() | ||
|
||
func.func @main() { | ||
// CHECK: Hello from OpenXLA CuDNN Module! | ||
call @cudnn.hello() : () -> () | ||
return | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
iree_cc_binary( | ||
NAME | ||
openxla-runner | ||
SRCS | ||
"openxla-runner.c" | ||
DEPS | ||
iree::base | ||
iree::base::internal::flags | ||
iree::base::tracing | ||
iree::hal | ||
iree::modules::hal::types | ||
iree::tooling::comparison | ||
iree::tooling::context_util | ||
iree::tooling::device_util | ||
iree::tooling::instrument_util | ||
iree::tooling::vm_util | ||
iree::vm | ||
openxla::runtime::nvgpu::cudnn_module | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include <stdio.h> | ||
|
||
#include "iree/modules/hal/types.h" | ||
#include "iree/runtime/api.h" | ||
#include "openxla/runtime/nvgpu/cudnn_module.h" | ||
|
||
// TODO: This is a temporary work around missing custom modules integration into | ||
// IREE tools (iree-run-module). We already have flags to enable plugins in | ||
// compiler tools (`iree-compiler` and `iree-opt`), but not yet in "runtime" | ||
// tools. This tool can only run VM function with empty arguments and empty | ||
// results, and intended for testing cuDNN custom module. | ||
int main(int argc, char** argv) { | ||
if (argc != 3) { | ||
fprintf(stderr, | ||
"Usage:\n" | ||
" openxla-runner - <entry.point> # read from stdin\n" | ||
" openxla-runner </path/to/say_hello.vmfb> " | ||
"<entry.point>\n"); | ||
return -1; | ||
} | ||
|
||
// Internally IREE does not (in general) use malloc and instead uses the | ||
// provided allocator to allocate and free memory. Applications can integrate | ||
// their own allocator as-needed. | ||
iree_allocator_t host_allocator = iree_allocator_system(); | ||
|
||
// Create and configure the instance shared across all sessions. | ||
iree_runtime_instance_options_t instance_options; | ||
iree_runtime_instance_options_initialize(&instance_options); | ||
iree_runtime_instance_options_use_all_available_drivers(&instance_options); | ||
iree_runtime_instance_t* instance = NULL; | ||
IREE_CHECK_OK(iree_runtime_instance_create(&instance_options, host_allocator, | ||
&instance)); | ||
|
||
// Try to create the device - it should always succeed as it's a CPU device. | ||
iree_hal_device_t* device = NULL; | ||
IREE_CHECK_OK(iree_runtime_instance_try_create_default_device( | ||
instance, iree_make_cstring_view("local-sync"), &device)); | ||
|
||
// Create one session per loaded module to hold the module state. | ||
iree_runtime_session_options_t session_options; | ||
iree_runtime_session_options_initialize(&session_options); | ||
iree_runtime_session_t* session = NULL; | ||
IREE_CHECK_OK(iree_runtime_session_create_with_device( | ||
instance, &session_options, device, | ||
iree_runtime_instance_host_allocator(instance), &session)); | ||
|
||
// Create the custom module that can be reused across contexts. | ||
iree_vm_module_t* custom_module = NULL; | ||
IREE_CHECK_OK(iree_custom_module_cudnn_create( | ||
iree_runtime_instance_vm_instance(instance), device, host_allocator, | ||
&custom_module)); | ||
IREE_CHECK_OK(iree_runtime_session_append_module(session, custom_module)); | ||
iree_vm_module_release(custom_module); | ||
|
||
// Load the module from stdin or a file on disk. | ||
const char* module_path = argv[1]; | ||
if (strcmp(module_path, "-") == 0) { | ||
IREE_CHECK_OK( | ||
iree_runtime_session_append_bytecode_module_from_stdin(session)); | ||
} else { | ||
IREE_CHECK_OK(iree_runtime_session_append_bytecode_module_from_file( | ||
session, module_path)); | ||
} | ||
|
||
iree_string_view_t entry_point = iree_make_cstring_view(argv[2]); | ||
fprintf(stdout, "INVOKE BEGIN %.*s\n", (int)entry_point.size, | ||
entry_point.data); | ||
fflush(stdout); | ||
|
||
iree_vm_list_t* inputs = NULL; | ||
IREE_CHECK_OK(iree_vm_list_create(NULL, 1, host_allocator, &inputs)); | ||
iree_vm_list_t* outputs = NULL; | ||
IREE_CHECK_OK(iree_vm_list_create(NULL, 1, host_allocator, &outputs)); | ||
|
||
// Synchronously invoke the requested function. | ||
IREE_CHECK_OK( | ||
iree_runtime_session_call_by_name(session, entry_point, inputs, outputs)); | ||
|
||
fprintf(stdout, "INVOKE END %.*s\n", (int)entry_point.size, entry_point.data); | ||
fflush(stdout); | ||
|
||
iree_vm_list_release(inputs); | ||
iree_vm_list_release(outputs); | ||
|
||
iree_runtime_session_release(session); | ||
iree_hal_device_release(device); | ||
iree_runtime_instance_release(instance); | ||
|
||
return 0; | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better to drive this from bazel_to_cmake (which will also aid downstream integration). Fine to come back and do that as a followup if not sure that it works (I had to make a number of tweaks to it).