Skip to content

Commit

Permalink
[Flang][OpenMP][MLIR] Fix common block mapping for regular and declar…
Browse files Browse the repository at this point in the history
…e target link

This PR attempts to fix common block mapping for regular mapping of these types as
well as when they have been marked as "declare target link". This PR should allow correct
mapping of both the members of a common block and the full common block via its
block symbol.

The main changes were some adjustments to the Fortran OpenMP lowering to HLFIR/FIR,
the lowering of the LLVM+OpenMP dialect to LLVM-IR and adjustments to the way the
we handle target kernel map argument rebinding inside of the OMPIRBuilder.

For the Fortran OpenMP lowering were two changes, one to prevent the implicit capture
of common block members when the common block symbol itself has been marked and
the other creates intermediate member access inside of the target region to be used
in-place of those external to the target region, this prevents external usages breaking the
IsolatedFromAbove pact.

In the latter case, there was an adjustment to the size calculation for types to better
handle cases where we pass an array as the type of a map (as opposed to the
bounds and the type of the element), which occurs in the case of common blocks. There
is also some adjustment to how handleDeclareTargetMapVar handles renaming of declare
target symbols in the module to the reference pointer, now it will only apply to those
within the kernel that is currently being generated and we also perform a modification
to replace constants with instructions as necessary as we cannot replace these with our
reference pointer (non-constant and constants do not mix nicely).

In the case of the OpenMPIRBuilder some changes were mde to defer global symbol
rebinding to kernel arguments until all other arguments have been rebound. This
makes sure we do not replace uses that may refer to the global (e.g. a GEP) but are
themselves actually a separate argument that needs bound.

Currently "declare target to" still needs some work, but this may be the case for all
types in conjunction with "declare target to" at the moment.
  • Loading branch information
agozillon committed May 11, 2024
1 parent e3ca558 commit 0c09c3a
Show file tree
Hide file tree
Showing 12 changed files with 607 additions and 36 deletions.
46 changes: 46 additions & 0 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,33 @@ static void genBodyOfTargetDataOp(
genNestedEvaluations(converter, eval);
}

// This generates intermediate common block member accesses within a region
// and then rebinds the members symbol to the intermediate accessors we have
// generated so that subsequent code generation will utilise these instead.
//
// When the scope changes, the bindings to the intermediate accessors should
// be dropped in place of the original symbol bindings.
//
// This is for utilisation with TargetOp.
static void genIntermediateCommonBlockAccessors(
Fortran::lower::AbstractConverter &converter,
const mlir::Location &currentLocation, mlir::Region &region,
llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms) {
for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) {
if (auto *details =
argSymbol->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
for (auto obj : details->objects()) {
auto targetCBMemberBind = Fortran::lower::genCommonBlockMember(
converter, currentLocation, *obj, region.getArgument(argIndex));
fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*obj);
fir::ExtendedValue targetCBExv =
getExtendedValue(sexv, targetCBMemberBind);
converter.bindSymbol(*obj, targetCBExv);
}
}
}
}

// This functions creates a block for the body of the targetOp's region. It adds
// all the symbols present in mapSymbols as block arguments to this block.
static void
Expand Down Expand Up @@ -983,6 +1010,18 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,

// Create the insertion point after the marker.
firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());

// If we map a common block using it's symbol e.g. map(tofrom: /common_block/)
// and accessing it's members within the target region, there is a large
// chance we will end up with uses external to the region accessing the common
// block. As target regions are IsolatedFromAbove, we must make sure to
// resolve these, we do so by generating new common block member accesses
// within the region, binding them to the member symbol for the scope of the
// region so that subsequent code generation within the region will utilise
// our new member accesses we have created.
genIntermediateCommonBlockAccessors(converter, currentLocation, region,
mapSyms);

if (genNested)
genNestedEvaluations(converter, eval);
}
Expand Down Expand Up @@ -1574,6 +1613,13 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
// symbols used inside the region that have not been explicitly mapped using
// the map clause.
auto captureImplicitMap = [&](const Fortran::semantics::Symbol &sym) {
// if the symbol is part of an already mapped common block, do not make a
// map for it.
if (const Fortran::semantics::Symbol *common =
Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate()))
if (llvm::find(mapSyms, common) != mapSyms.end())
return;

if (llvm::find(mapSyms, &sym) == mapSyms.end()) {
mlir::Value baseOp = converter.getSymbolAddress(sym);
if (!baseOp)
Expand Down
41 changes: 41 additions & 0 deletions flang/test/Integration/OpenMP/map-types-and-sizes.f90
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,31 @@ subroutine mapType_char
!$omp end target
end subroutine mapType_char

!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [1 x i64] [i64 8]
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [1 x i64] [i64 35]
subroutine mapType_common_block
implicit none
common /var_common/ var1, var2
integer :: var1, var2
!$omp target map(tofrom: /var_common/)
var1 = var1 + 20
var2 = var2 + 30
!$omp end target
end subroutine mapType_common_block

!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [2 x i64] [i64 4, i64 4]
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [2 x i64] [i64 35, i64 35]
subroutine mapType_common_block_members
implicit none
common /var_common/ var1, var2
integer :: var1, var2

!$omp target map(tofrom: var1, var2)
var2 = var1
!$omp end target
end subroutine mapType_common_block_members


!CHECK-LABEL: define {{.*}} @{{.*}}maptype_ptr_explicit_{{.*}}
!CHECK: %[[ALLOCA:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, i64 1, align 8
!CHECK: %[[ALLOCA_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOCA]], i32 1
Expand Down Expand Up @@ -346,3 +371,19 @@ end subroutine mapType_char
!CHECK: store ptr %[[ALLOCA]], ptr %[[BASE_PTR_ARR]], align 8
!CHECK: %[[OFFLOAD_PTR_ARR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
!CHECK: store ptr %[[ARR_OFF]], ptr %[[OFFLOAD_PTR_ARR]], align 8

!CHECK-LABEL: define {{.*}} @{{.*}}maptype_common_block_{{.*}}
!CHECK: %[[BASE_PTR_ARR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
!CHECK: store ptr @var_common_, ptr %[[BASE_PTR_ARR]], align 8
!CHECK: %[[OFFLOAD_PTR_ARR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
!CHECK: store ptr @var_common_, ptr %[[OFFLOAD_PTR_ARR]], align 8

!CHECK-LABEL: define {{.*}} @{{.*}}maptype_common_block_members_{{.*}}
!CHECK: %[[BASE_PTR_ARR:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
!CHECK: store ptr @var_common_, ptr %[[BASE_PTR_ARR]], align 8
!CHECK: %[[OFFLOAD_PTR_ARR:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_ptrs, i32 0, i32 0
!CHECK: store ptr @var_common_, ptr %[[OFFLOAD_PTR_ARR]], align 8
!CHECK: %[[BASE_PTR_ARR_1:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
!CHECK: store ptr getelementptr (i8, ptr @var_common_, i64 4), ptr %[[BASE_PTR_ARR_1]], align 8
!CHECK: %[[OFFLOAD_PTR_ARR_1:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_ptrs, i32 0, i32 1
!CHECK: store ptr getelementptr (i8, ptr @var_common_, i64 4), ptr %[[OFFLOAD_PTR_ARR_1]], align 8
83 changes: 83 additions & 0 deletions flang/test/Lower/OpenMP/common-block-map.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

!CHECK: fir.global common @var_common_(dense<0> : vector<8xi8>) : !fir.array<8xi8>
!CHECK: fir.global common @var_common_link_(dense<0> : vector<8xi8>) {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (link)>} : !fir.array<8xi8>

!CHECK-LABEL: func.func @_QPmap_full_block
!CHECK: %[[CB_ADDR:.*]] = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[CB_ADDR]] : !fir.ref<!fir.array<8xi8>>, !fir.array<8xi8>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.array<8xi8>> {name = "var_common"}
!CHECK: omp.target map_entries(%[[MAP]] -> %[[MAP_ARG:.*]] : !fir.ref<!fir.array<8xi8>>) {
!CHECK: ^bb0(%[[MAP_ARG]]: !fir.ref<!fir.array<8xi8>>):
!CHECK: %[[CONV:.*]] = fir.convert %[[MAP_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX:.*]] = arith.constant 0 : index
!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV2:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[CB_MEMBER_1:.*]]:2 = hlfir.declare %[[CONV2]] {uniq_name = "_QFmap_full_blockEvar1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[CONV3:.*]] = fir.convert %[[MAP_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX2:.*]] = arith.constant 4 : index
!CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[CONV3]], %[[INDEX2]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV4:.*]] = fir.convert %[[COORD2]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[CB_MEMBER_2:.*]]:2 = hlfir.declare %[[CONV4]] {uniq_name = "_QFmap_full_blockEvar2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
subroutine map_full_block
implicit none
common /var_common/ var1, var2
integer :: var1, var2
!$omp target map(tofrom: /var_common/)
var1 = var1 + 20
var2 = var2 + 30
!$omp end target
end

!CHECK-LABEL: @_QPmap_mix_of_members
!CHECK: %[[COMMON_BLOCK:.*]] = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
!CHECK: %[[CB_CONV:.*]] = fir.convert %[[COMMON_BLOCK]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX:.*]] = arith.constant 0 : index
!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CB_CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[CB_MEMBER_1:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFmap_mix_of_membersEvar1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[CB_CONV:.*]] = fir.convert %[[COMMON_BLOCK]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX:.*]] = arith.constant 4 : index
!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CB_CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[CB_MEMBER_2:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFmap_mix_of_membersEvar2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[MAP_EXP:.*]] = omp.map.info var_ptr(%[[CB_MEMBER_2]]#0 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "var2"}
!CHECK: %[[MAP_IMP:.*]] = omp.map.info var_ptr(%[[CB_MEMBER_1]]#1 : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "var1"}
!CHECK: omp.target map_entries(%[[MAP_EXP]] -> %[[ARG_EXP:.*]], %[[MAP_IMP]] -> %[[ARG_IMP:.*]] : !fir.ref<i32>, !fir.ref<i32>) {
!CHECK: ^bb0(%[[ARG_EXP]]: !fir.ref<i32>, %[[ARG_IMP]]: !fir.ref<i32>):
!CHECK: %[[EXP_MEMBER:.*]]:2 = hlfir.declare %[[ARG_EXP]] {uniq_name = "_QFmap_mix_of_membersEvar2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[IMP_MEMBER:.*]]:2 = hlfir.declare %[[ARG_IMP]] {uniq_name = "_QFmap_mix_of_membersEvar1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
subroutine map_mix_of_members
implicit none
common /var_common/ var1, var2
integer :: var1, var2

!$omp target map(tofrom: var2)
var2 = var1
!$omp end target
end

!CHECK-LABEL: @_QQmain
!CHECK: %[[DECL_TAR_CB:.*]] = fir.address_of(@var_common_link_) : !fir.ref<!fir.array<8xi8>>
!CHECK: %[[MAP_DECL_TAR_CB:.*]] = omp.map.info var_ptr(%[[DECL_TAR_CB]] : !fir.ref<!fir.array<8xi8>>, !fir.array<8xi8>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.array<8xi8>> {name = "var_common_link"}
!CHECK: omp.target map_entries(%[[MAP_DECL_TAR_CB]] -> %[[MAP_DECL_TAR_ARG:.*]] : !fir.ref<!fir.array<8xi8>>) {
!CHECK: ^bb0(%[[MAP_DECL_TAR_ARG]]: !fir.ref<!fir.array<8xi8>>):
!CHECK: %[[CONV:.*]] = fir.convert %[[MAP_DECL_TAR_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX:.*]] = arith.constant 0 : index
!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[MEMBER_ONE:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFElink1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[CONV:.*]] = fir.convert %[[MAP_DECL_TAR_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX:.*]] = arith.constant 4 : index
!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[MEMBER_TWO:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFElink2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
program main
implicit none
common /var_common_link/ link1, link2
integer :: link1, link2
!$omp declare target link(/var_common_link/)

!$omp target map(tofrom: /var_common_link/)
link1 = link2 + 20
!$omp end target
end program
5 changes: 5 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2110,6 +2110,11 @@ class OpenMPIRBuilder {
int32_t UB);
///}

/// Replaces constant values with instruction equivelants where possible
/// inside of a function.
static void replaceConstantValueUsesInFuncWithInstr(llvm::Value *Input,
Function *Func);

private:
// Sets the function attributes expected for the outlined function
void setOutlinedTargetRegionFunctionAttributes(Function *OutlinedFn);
Expand Down
74 changes: 54 additions & 20 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5085,8 +5085,8 @@ static void replaceConstatExprUsesInFuncWithInstr(ConstantExpr *ConstExpr,
}
}

static void replaceConstantValueUsesInFuncWithInstr(llvm::Value *Input,
Function *Func) {
void OpenMPIRBuilder::replaceConstantValueUsesInFuncWithInstr(
llvm::Value *Input, Function *Func) {
for (User *User : make_early_inc_range(Input->users()))
if (auto *Const = dyn_cast<Constant>(User))
if (auto *ConstExpr = dyn_cast<ConstantExpr>(Const))
Expand Down Expand Up @@ -5160,6 +5160,31 @@ static Function *createOutlinedFunction(
? make_range(Func->arg_begin() + 1, Func->arg_end())
: Func->args();

auto ReplaceValue = [&OMPBuilder](Value *Input, Value *InputCopy,
Function *Func) {
// Things like GEP's can come in the form of Constants. Constants and
// ConstantExpr's do not have access to the knowledge of what they're
// contained in, so we must dig a little to find an instruction so we
// can tell if they're used inside of the function we're outlining. We
// also replace the original constant expression with a new instruction
// equivalent; an instruction as it allows easy modification in the
// following loop, as we can now know the constant (instruction) is
// owned by our target function and replaceUsesOfWith can now be invoked
// on it (cannot do this with constants it seems). A brand new one also
// allows us to be cautious as it is perhaps possible the old expression
// was used inside of the function but exists and is used externally
// (unlikely by the nature of a Constant, but still).
OMPBuilder.replaceConstantValueUsesInFuncWithInstr(Input, Func);

// Collect all the instructions
for (User *User : make_early_inc_range(Input->users()))
if (auto *Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
Instr->replaceUsesOfWith(Input, InputCopy);
};

SmallVector<std::pair<Value *, Value *>> DeferredReplacement;

// Rewrite uses of input valus to parameters.
for (auto InArg : zip(Inputs, ArgRange)) {
Value *Input = std::get<0>(InArg);
Expand All @@ -5169,27 +5194,36 @@ static Function *createOutlinedFunction(
Builder.restoreIP(
ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));

// Things like GEP's can come in the form of Constants. Constants and
// ConstantExpr's do not have access to the knowledge of what they're
// contained in, so we must dig a little to find an instruction so we can
// tell if they're used inside of the function we're outlining. We also
// replace the original constant expression with a new instruction
// equivalent; an instruction as it allows easy modification in the
// following loop, as we can now know the constant (instruction) is owned by
// our target function and replaceUsesOfWith can now be invoked on it
// (cannot do this with constants it seems). A brand new one also allows us
// to be cautious as it is perhaps possible the old expression was used
// inside of the function but exists and is used externally (unlikely by the
// nature of a Constant, but still).
replaceConstantValueUsesInFuncWithInstr(Input, Func);
// In certain cases a Global may be set up for replacement, however, this
// Global may be used in multiple arguments to the kernel, just segmented
// apart, for example, if we have a global array, that is sectioned into
// multiple mappings (technically not legal in OpenMP, but there is a case
// in Fortran for Common Blocks where this is neccesary), we will end up
// with GEP's into this array inside the kernel, that refer to the Global
// but are technically seperate arguments to the kernel for all intents and
// purposes. If we have mapped a segment that requires a GEP into the 0-th
// index, it will fold into an referal to the Global, if we then encounter
// this folded GEP during replacement all of the references to the
// Global in the kernel will be replaced with the argument we have generated
// that corresponds to it, including any other GEP's that refer to the
// Global that may be other arguments. This will invalidate all of the other
// preceding mapped arguments that refer to the same global that may be
// seperate segments. To prevent this, we defer global processing until all
// other processing has been performed.
if (llvm::isa<llvm::GlobalValue>(std::get<0>(InArg)) ||
llvm::isa<llvm::GlobalObject>(std::get<0>(InArg)) ||
llvm::isa<llvm::GlobalVariable>(std::get<0>(InArg))) {
DeferredReplacement.push_back(std::make_pair(Input, InputCopy));
continue;
}

// Collect all the instructions
for (User *User : make_early_inc_range(Input->users()))
if (auto *Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
Instr->replaceUsesOfWith(Input, InputCopy);
ReplaceValue(Input, InputCopy, Func);
}

// Replace all of our deferred Input values, currently just Globals.
for (auto Deferred : DeferredReplacement)
ReplaceValue(std::get<0>(Deferred), std::get<1>(Deferred), Func);

// Restore insert point.
Builder.restoreIP(OldInsertPoint);

Expand Down
Loading

0 comments on commit 0c09c3a

Please sign in to comment.