Skip to content

Commit

Permalink
Merge pull request #589 from jjfumero/feat/spirv/unroll
Browse files Browse the repository at this point in the history
Enable Partial Loop Unroll for all Backends
  • Loading branch information
jjfumero authored Nov 12, 2024
2 parents 0db1cd3 + cd51fd2 commit 73d1dfb
Show file tree
Hide file tree
Showing 14 changed files with 302 additions and 57 deletions.
2 changes: 2 additions & 0 deletions tornado-assembly/src/etc/exportLists/common-exports
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,5 @@
--add-exports jdk.internal.vm.compiler/org.graalvm.compiler.nodes.virtual=tornado.drivers.common
--add-exports jdk.internal.vm.compiler/org.graalvm.compiler.loop.phases=tornado.drivers.common
--add-exports jdk.internal.vm.compiler/org.graalvm.compiler.core.common.util=tornado.drivers.common
--add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.tiers=tornado.drivers.common
--add-exports jdk.internal.vm.compiler/org.graalvm.compiler.phases.common=tornado.drivers.common
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,17 @@

import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.loop.LoopEx;
import org.graalvm.compiler.nodes.loop.LoopFragmentInside;
import org.graalvm.compiler.nodes.loop.LoopsData;
import org.graalvm.compiler.phases.BasePhase;
import org.graalvm.compiler.phases.common.CanonicalizerPhase;
import org.graalvm.compiler.phases.common.DeadCodeEliminationPhase;
import org.graalvm.compiler.phases.tiers.MidTierContext;

import uk.ac.manchester.tornado.runtime.TornadoCoreRuntime;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
import uk.ac.manchester.tornado.runtime.graal.nodes.TornadoLoopsData;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoMidTierContext;

/**
* Applies partial unroll on counted loops of more than 128 elements. By default,
Expand All @@ -44,27 +45,45 @@
* @see org.graalvm.compiler.loop.phases.LoopTransformations
*/

public class TornadoPartialLoopUnroll extends BasePhase<MidTierContext> {
public class TornadoPartialLoopUnrollPhase extends BasePhase<MidTierContext> {

private static final int LOOP_UNROLL_FACTOR_DEFAULT = 2;
private static final int LOOP_BOUND_UPPER_LIMIT = 16384;

private static final int GRAPH_NODES_UPPER_LIMIT = 40000;

private static void partialUnroll(StructuredGraph graph, MidTierContext context) {
final LoopsData dataCounted = new TornadoLoopsData(graph);
/**
 * Outcome reported by {@code partialUnroll} for a single unrolling pass.
 */
private enum OptimizationStatus {
// The pass ran to completion.
SUCCESS, //
// A NullPointerException was caught while building the loop data or while duplicating loop bodies.
ERROR;
}

private static OptimizationStatus partialUnroll(StructuredGraph graph, MidTierContext context) {

LoopsData dataCounted;
try {
dataCounted = new TornadoLoopsData(graph);
} catch (NullPointerException nullPointerException) {
return OptimizationStatus.ERROR;
}

CanonicalizerPhase canonicalizer = CanonicalizerPhase.create();

canonicalizer.apply(graph, context);
dataCounted.detectCountedLoops();
for (LoopEx loop : dataCounted.countedLoops()) {
int loopBound = loop.counted().getLimit().asJavaConstant().asInt();
if (isPowerOfTwo(loopBound) && (loopBound < LOOP_BOUND_UPPER_LIMIT)) {
LoopFragmentInside newSegment = loop.inside().duplicate();
newSegment.insertWithinAfter(loop, null);
}
try {
dataCounted.countedLoops().forEach(loop -> {
int loopBound = loop.counted().getLimit().asJavaConstant().asInt();
if (isPowerOfTwo(loopBound) && (loopBound < LOOP_BOUND_UPPER_LIMIT)) {
LoopFragmentInside loopBody = loop.inside().duplicate();
loopBody.insertWithinAfter(loop, null);
}
});

new DeadCodeEliminationPhase().apply(graph);
} catch (NullPointerException runtimeException) {
return OptimizationStatus.ERROR;
}
new DeadCodeEliminationPhase().apply(graph);
return OptimizationStatus.SUCCESS;
}

private static int getUnrollFactor() {
Expand All @@ -84,19 +103,33 @@ public Optional<NotApplicable> notApplicableTo(GraphState graphState) {
return ALWAYS_APPLICABLE;
}

/**
 * Selects the graph to keep after an unrolling pass.
 *
 * @param graph
 *     the graph the pass was applied to
 * @param snapshot
 *     the copy taken before unrolling started
 * @param status
 *     the outcome of the pass
 * @return {@code graph} when the pass succeeded, otherwise {@code snapshot}
 */
private StructuredGraph checkStatus(StructuredGraph graph, StructuredGraph snapshot, OptimizationStatus status) {
    if (status == OptimizationStatus.SUCCESS) {
        return graph;
    }
    return snapshot;
}

/**
 * Applies partial loop unrolling to {@code graph}.
 *
 * <p>
 * The phase is a no-op when the task meta-data does not request partial
 * unrolling or when the graph contains no loops. Otherwise it repeats the
 * unrolling pass while {@code 2^i < unrollFactor}, skipping a pass whenever
 * the graph has already reached the node-count limit derived from its
 * initial size. The phase stops at the first pass that does not report
 * {@link OptimizationStatus#SUCCESS}.
 *
 * @param graph
 *     the graph to transform in place
 * @param context
 *     the mid-tier context; it is cast to {@link TornadoMidTierContext}, so
 *     any other context type fails with a {@link ClassCastException}
 */
@Override
protected void run(StructuredGraph graph, MidTierContext context) {

    TornadoMidTierContext tornadoMidTierContext = (TornadoMidTierContext) context;
    if (!tornadoMidTierContext.getMeta().applyPartialLoopUnroll()) {
        return;
    }

    if (!graph.hasLoops()) {
        return;
    }

    int initialNodeCount = graph.getNodeCount();
    int unrollFactor = getUnrollFactor();

    // Copy taken before any transformation, handed to checkStatus as the fallback graph.
    StructuredGraph snapshot = (StructuredGraph) graph.copy(TornadoCoreRuntime.getDebugContext());
    for (int i = 0; Math.pow(2, i) < unrollFactor; i++) {
        if (graph.getNodeCount() < getUpperGraphLimit(initialNodeCount)) {
            // Invoke the pass exactly once per iteration and always inspect its status.
            OptimizationStatus status = partialUnroll(graph, context);
            // NOTE(review): this only rebinds the local parameter; the caller's graph
            // reference is not replaced by the snapshot. Confirm whether a real
            // rollback of the caller-visible graph is required on ERROR.
            graph = checkStatus(graph, snapshot, status);
            if (status != OptimizationStatus.SUCCESS) {
                return;
            }
        }
    }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.BoundCheckEliminationPhase;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.ExceptionCheckingElimination;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.loops.TornadoPartialLoopUnroll;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.loops.TornadoPartialLoopUnrollPhase;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoPanamaSegmentsHeaderPhase;
import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoFloatingReadReplacement;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
Expand Down Expand Up @@ -75,7 +75,7 @@ public OCLMidTier(OptionValues options) {
appendPhase(canonicalizer);

if (TornadoOptions.isPartialUnrollEnabled()) {
appendPhase(new TornadoPartialLoopUnroll());
appendPhase(new TornadoPartialLoopUnrollPhase());
}

appendPhase(new MidTierLoweringPhase(canonicalizer));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@

import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.BoundCheckEliminationPhase;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.ExceptionCheckingElimination;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.loops.TornadoPartialLoopUnrollPhase;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoPanamaSegmentsHeaderPhase;
import uk.ac.manchester.tornado.drivers.ptx.graal.phases.TornadoFloatingReadReplacement;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
import uk.ac.manchester.tornado.runtime.graal.compiler.TornadoMidTier;

public class PTXMidTier extends TornadoMidTier {
Expand Down Expand Up @@ -68,6 +70,10 @@ public PTXMidTier(OptionValues options) {

appendPhase(canonicalizer);

if (TornadoOptions.isPartialUnrollEnabled()) {
appendPhase(new TornadoPartialLoopUnrollPhase());
}

appendPhase(new MidTierLoweringPhase(canonicalizer));

appendPhase(new FrameStateAssignmentPhase());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
Expand Down Expand Up @@ -229,13 +229,13 @@ public Variable emitConditionalMove(PlatformKind cmpKind, Value leftVal, Value r
* based on a bitwise AND operation between two values.
*
* @param leftVal
* the left value of a condition
* the left value of a condition
* @param right
* the right value of a condition
* the right value of a condition
* @param trueValue
* the true value to move in the result
* the true value to move in the result
* @param falseValue
* the false value to move in the result
* the false value to move in the result
* @return Variable: reference to the variable that contains the result
*/
@Override
Expand Down Expand Up @@ -387,8 +387,8 @@ public SPIRVArithmeticTool getArithmetic() {
return (SPIRVArithmeticTool) super.getArithmetic();
}

public void emitConditionalBranch(Value condition, LabelRef trueBranch, LabelRef falseBranch) {
append(new SPIRVControlFlow.BranchConditional(condition, trueBranch, falseBranch));
public void emitConditionalBranch(Value condition, LabelRef trueBranch, LabelRef falseBranch, int unrollFactor) {
append(new SPIRVControlFlow.BranchConditional(condition, trueBranch, falseBranch, unrollFactor));
}

public void emitJump(LabelRef label, boolean isLoopEdgeBack) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,29 @@
import uk.ac.manchester.tornado.drivers.common.compiler.phases.utils.DumpLowTierGraph;
import uk.ac.manchester.tornado.drivers.opencl.graal.phases.OCLFPGAPragmaPhase;
import uk.ac.manchester.tornado.drivers.opencl.graal.phases.OCLFPGAThreadScheduler;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.*;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.InverseSquareRootPhase;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.PartialLoopUnrollPhase;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.SPIRVFMAPhase;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.SPIRVFP64SupportPhase;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoFixedArrayCopyPhase;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoHalfFloatVectorOffset;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
import uk.ac.manchester.tornado.runtime.graal.compiler.TornadoLowTier;

public class SPIRVLowTier extends TornadoLowTier {

public SPIRVLowTier(OptionValues options, TornadoDeviceContext deviceContext, AddressLoweringByNodePhase.AddressLowering addressLowering) {
CanonicalizerPhase canonicalizer = getCannonicalizer(options);

CanonicalizerPhase canonicalizer = getCannonicalizer();

appendPhase(new SPIRVFP64SupportPhase(deviceContext));

appendPhase(new LowTierLoweringPhase(canonicalizer));

if (TornadoOptions.ENABLE_SPIRV_LOOP_UNROLL) {
appendPhase(new PartialLoopUnrollPhase());
}

if (ConditionalElimination.getValue(options)) {
appendPhase(new IterativeConditionalEliminationPhase(canonicalizer, false));
}
Expand Down Expand Up @@ -100,7 +110,7 @@ public SPIRVLowTier(OptionValues options, TornadoDeviceContext deviceContext, Ad
}
}

private CanonicalizerPhase getCannonicalizer(OptionValues options) {
private CanonicalizerPhase getCannonicalizer() {
return CanonicalizerPhase.create();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
import org.graalvm.compiler.phases.common.ReassociationPhase;
import org.graalvm.compiler.phases.common.RemoveValueProxyPhase;

import uk.ac.manchester.tornado.drivers.common.compiler.phases.loops.TornadoPartialLoopUnroll;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.BoundCheckEliminationPhase;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.guards.ExceptionCheckingElimination;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.loops.TornadoPartialLoopUnrollPhase;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoPanamaSegmentsHeaderPhase;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoFloatingReadReplacement;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
Expand Down Expand Up @@ -88,7 +88,7 @@ public SPIRVMidTier(OptionValues options) {
appendPhase(canonicalizer);

if (TornadoOptions.isPartialUnrollEnabled()) {
appendPhase(new TornadoPartialLoopUnroll());
appendPhase(new TornadoPartialLoopUnrollPhase());
}

appendPhase(new MidTierLoweringPhase(canonicalizer));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
Expand All @@ -36,6 +36,7 @@
import java.util.Map;

import org.graalvm.compiler.core.common.LIRKind;
import org.graalvm.compiler.core.common.cfg.BasicBlock;
import org.graalvm.compiler.core.common.cfg.BlockMap;
import org.graalvm.compiler.core.common.type.ObjectStamp;
import org.graalvm.compiler.core.common.type.Stamp;
Expand All @@ -54,12 +55,14 @@
import org.graalvm.compiler.lir.gen.LIRGenerator;
import org.graalvm.compiler.lir.gen.LIRGeneratorTool;
import org.graalvm.compiler.lir.gen.LIRGeneratorTool.BlockScope;
import org.graalvm.compiler.nodes.AbstractBeginNode;
import org.graalvm.compiler.nodes.AbstractEndNode;
import org.graalvm.compiler.nodes.AbstractMergeNode;
import org.graalvm.compiler.nodes.BeginNode;
import org.graalvm.compiler.nodes.BreakpointNode;
import org.graalvm.compiler.nodes.DirectCallTargetNode;
import org.graalvm.compiler.nodes.EndNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.IndirectCallTargetNode;
import org.graalvm.compiler.nodes.Invoke;
Expand Down Expand Up @@ -107,7 +110,7 @@
import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVKind;
import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVLIRStmt;
import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVUnary;
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.PragmaUnrollNode;
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.PartialUnrollNode;
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.ThreadConfigurationNode;
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.vector.SPIRVVectorValueNode;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
Expand Down Expand Up @@ -520,12 +523,27 @@ public void emitIf(final IfNode x) {
final boolean isLoop = gen.getCurrentBlock().isLoopHeader();
final boolean isNegated = isLoop && x.trueSuccessor() instanceof LoopExitNode;

final Variable condition = emitLogicNode(x.condition());
boolean isConditionFromParallelLoop = false;
int unrollFactor = 0;
if (isLoop) {
getGen().emitConditionalBranch(condition, getLIRBlock(x.trueSuccessor()), getLIRBlock(x.falseSuccessor()));
} else {
getGen().emitConditionalBranch(condition, getLIRBlock(x.trueSuccessor()), getLIRBlock(x.falseSuccessor()));
BasicBlock<?> block = gen.getCurrentBlock();
HIRBlock hirBlock = (HIRBlock) block;
AbstractBeginNode beginNode = hirBlock.getBeginNode();
if (beginNode instanceof LoopBeginNode loopBeginNode) {
// Once pragma is inserted, it is easier to analyze if partial unroll is possible
FixedNode successor = loopBeginNode.next();
if (successor instanceof PartialUnrollNode partialUnrollNode) {
isConditionFromParallelLoop = true;
unrollFactor = partialUnrollNode.getPartialUnrollFactor();
}
}
if (!isConditionFromParallelLoop) {
Logger.traceBuildLIR(Logger.BACKEND.SPIRV, "emitLoopUnroll");
}
}

final Variable condition = emitLogicNode(x.condition());
getGen().emitConditionalBranch(condition, getLIRBlock(x.trueSuccessor()), getLIRBlock(x.falseSuccessor()), unrollFactor);
}

@Override
Expand Down Expand Up @@ -744,7 +762,7 @@ protected void emitNode(final ValueNode node) {
emitLoopExit((LoopExitNode) node);
} else if (node instanceof ShortCircuitOrNode) {
throw new RuntimeException("Unimplemented");
} else if (node instanceof PragmaUnrollNode || node instanceof ThreadConfigurationNode) {
} else if (node instanceof PartialUnrollNode || node instanceof ThreadConfigurationNode) {
// ignore emit-action
} else {
super.emitNode(node);
Expand Down
Loading

0 comments on commit 73d1dfb

Please sign in to comment.