Skip to content

Commit

Permalink
[Mosaic GPU] Load LLVM lowering interfaces for all dialects
Browse files Browse the repository at this point in the history
Apparently we were missing interface registration code for LLVM lowering,
which the gpu-to-llvm pass gracefully ignores unless compiled with debug
assertions enabled. But, simply adding the assertions in fact makes the
pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not
what we want. That's why I've replaced it with a minimal version that is
only repsponsible for handling the GPU dialect, making the lowering similar
to the one prior to extra registrations.

PiperOrigin-RevId: 641874183
  • Loading branch information
apaszke authored and jax authors committed Jun 10, 2024
1 parent 2ade7e7 commit 3b4039c
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 6 deletions.
24 changes: 21 additions & 3 deletions jaxlib/mosaic/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,27 @@ py_library(

cc_library(
name = "passes",
srcs = ["launch_lowering.cc"],
hdrs = ["launch_lowering.h"],
srcs = [
"launch_lowering.cc",
"passes.cc",
],
hdrs = [
"launch_lowering.h",
"pass_boilerplate.h",
"passes.h",
],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:DataLayoutInterfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@com_google_absl//absl/log",
"@llvm-project//mlir:TransformUtils",
],
)

Expand Down Expand Up @@ -97,29 +106,38 @@ cc_library(
":passes",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:ComplexToLLVM",
"@llvm-project//mlir:ControlFlowToLLVM",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:ExecutionEngine",
"@llvm-project//mlir:ExecutionEngineUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncToLLVM",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToLLVMIRTranslation",
"@llvm-project//mlir:GPUTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:IndexToLLVM",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MathToLLVM",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefToLLVM",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:NVGPUDialect",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:NVVMTarget",
"@llvm-project//mlir:NVVMToLLVM",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:UBToLLVM",
"@llvm-project//mlir:VectorDialect",
"@xla//xla/service:custom_call_status",
"@xla//xla/service:custom_call_target_registry",
Expand Down
23 changes: 22 additions & 1 deletion jaxlib/mosaic/gpu/custom_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,16 @@ limitations under the License.
#include "llvm/include/llvm/ADT/SmallVector.h"
#include "llvm/include/llvm/Support/CodeGen.h"
#include "llvm/include/llvm/Support/TargetSelect.h"
#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/include/mlir/Conversion/Passes.h"
#include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
Expand Down Expand Up @@ -67,6 +76,7 @@ limitations under the License.
#include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/include/mlir/Transforms/Passes.h"
#include "jaxlib/mosaic/gpu/launch_lowering.h"
#include "jaxlib/mosaic/gpu/passes.h"
#include "xla/service/custom_call_status.h"
#include "xla/service/custom_call_target_registry.h"

Expand Down Expand Up @@ -100,6 +110,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
mlir::memref::registerMemRefPasses();
mlir::registerGPUPasses();
mosaic::gpu::registerGpuLaunchLoweringPass();
mosaic::gpu::registerConvertGpuToLLVMPass();
return true;
}();
(void)register_once;
Expand All @@ -123,7 +134,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true}),
gpu.module(cse),
gpu.module(reconcile-unrealized-casts),
gpu-to-llvm{gpu-binary-annotation=gpu.binary use-bare-pointers-for-host=false use-bare-pointers-for-kernels=false},
mosaic-convert-gpu-to-llvm,
gpu-module-to-binary{format=)" +
mlir::gpu::stringifyCompilationTarget(target).str() + R"(},
convert-math-to-llvm{approximate-log1p=true},
Expand Down Expand Up @@ -152,6 +163,16 @@ void InitContext(mlir::MLIRContext* context) {
mlir::scf::SCFDialect, mlir::vector::VectorDialect,
mlir::gpu::GPUDialect, mlir::nvgpu::NVGPUDialect,
mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
mlir::registerConvertNVVMToLLVMInterface(registry);
mlir::registerConvertComplexToLLVMInterface(registry);
mlir::registerConvertMemRefToLLVMInterface(registry);
mlir::registerConvertMathToLLVMInterface(registry);
mlir::registerConvertFuncToLLVMInterface(registry);
mlir::index::registerConvertIndexToLLVMInterface(registry);
mlir::cf::registerConvertControlFlowToLLVMInterface(registry);
mlir::ub::registerConvertUBToLLVMInterface(registry); // Arith needs this
mlir::arith::registerConvertArithToLLVMInterface(registry);
mlir::registerFinalizeMemRefToLLVMConversionPass();
mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry);
mlir::registerBuiltinDialectTranslation(registry);
Expand Down
2 changes: 1 addition & 1 deletion jaxlib/mosaic/gpu/launch_lowering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ void buildInitFunction(mlir::OpBuilder &module_builder,
used_smem = builder.create<mlir::LLVM::ConstantOp>(
loc, i32,
builder.getI32IntegerAttr(
mlir::cast<mlir::IntegerAttr>(const_smem.getValue()).getSInt()));
mlir::cast<mlir::IntegerAttr>(const_smem.getValue()).getInt()));
}
}
mlir::Value kernel_handle =
Expand Down
64 changes: 64 additions & 0 deletions jaxlib/mosaic/gpu/pass_boilerplate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_
#define JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_

#include "mlir/include/mlir/IR/DialectRegistry.h"
#include "mlir/include/mlir/Pass/Pass.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "mlir/include/mlir/Support/TypeID.h"
namespace mosaic {
namespace gpu {

template <typename Derived, typename Op = void>
class Pass : public ::mlir::OperationPass<Op> {
public:
Pass() : ::mlir::OperationPass<Op>(::mlir::TypeID::get<Derived>()) {}
Pass(const Pass &other) : ::mlir::OperationPass<Op>(other) {}
Pass &operator=(const Pass &) = delete;
Pass(Pass &&) = delete;
Pass &operator=(Pass &&) = delete;
~Pass() = default;

static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral(Derived::kArgumentName);
}
::llvm::StringRef getArgument() const override { return getArgumentName(); }
::llvm::StringRef getDescription() const override { return ""; }
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral(Derived::kPassName);
}
::llvm::StringRef getName() const override { return getPassName(); }
static bool classof(const ::mlir::Pass *pass) {
return pass->getTypeID() == ::mlir::TypeID::get<Derived>();
}
std::unique_ptr<::mlir::Pass> clonePass() const override {
return std::make_unique<Derived>(*static_cast<const Derived *>(this));
}
void getDependentDialects(::mlir::DialectRegistry &registry) const override {}

private:
using This =
Pass<Derived, Op>; // Can't have a comma in the macro instantiation

public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(This)
};

} // namespace gpu
} // namespace mosaic

#endif // JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_
81 changes: 81 additions & 0 deletions jaxlib/mosaic/gpu/passes.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/mosaic/gpu/passes.h"
#include <memory>
#include <utility>
#include <vector>

#include "llvm/include/llvm/ADT/StringRef.h"
#include "llvm/include/llvm/Support/Debug.h"
#include "mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/SymbolTable.h"
#include "mlir/include/mlir/Pass/PassRegistry.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "mlir/include/mlir/Transforms/DialectConversion.h"
#include "jaxlib/mosaic/gpu/pass_boilerplate.h"

namespace mosaic {
namespace gpu {

namespace {

class ConvertGpuToLLVMPass
: public mosaic::gpu::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp> {
public:
using mosaic::gpu::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp>::Pass;
static constexpr llvm::StringLiteral kArgumentName =
"mosaic-convert-gpu-to-llvm";
static constexpr llvm::StringLiteral kPassName = "ConvertGpuToLLVMPass";

void runOnOperation() override {
llvm::DebugFlag = true;
mlir::MLIRContext *ctx = &getContext();
mlir::RewritePatternSet patterns(ctx);
mlir::LLVMTypeConverter converter(ctx);
mlir::ConversionTarget target(*ctx);
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
target.addLegalOp<mlir::gpu::GPUModuleOp>();
target.addDynamicallyLegalOp<mlir::gpu::LaunchFuncOp>(
[&](mlir::gpu::LaunchFuncOp op) -> bool {
return converter.isLegal(op->getOperandTypes()) &&
converter.isLegal(op->getResultTypes());
});
auto symtab = mlir::SymbolTable(getOperation());
mlir::populateGpuToLLVMConversionPatterns(converter, patterns, "gpu.binary",
false, &symtab);
if (mlir::applyPartialConversion(getOperation(), target,
std::move(patterns))
.failed()) {
signalPassFailure();
}
llvm::DebugFlag = false;
}
};

} // namespace

void registerConvertGpuToLLVMPass() {
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return std::make_unique<ConvertGpuToLLVMPass>();
});
}

} // namespace gpu
} // namespace mosaic
27 changes: 27 additions & 0 deletions jaxlib/mosaic/gpu/passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_
#define JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_

namespace mosaic {
namespace gpu {

void registerConvertGpuToLLVMPass();

} // namespace gpu
} // namespace mosaic

#endif // JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_
2 changes: 1 addition & 1 deletion tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def setUp(self):

class TestUtilTest(TestCase):

def test_copy(self):
def test_copy_basic(self):
def kernel(ctx, src, dst, _):
copy(src, dst)
x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3)
Expand Down

0 comments on commit 3b4039c

Please sign in to comment.