Skip to content

Commit

Permalink
Pdll changes
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaTipturMadhu committed Sep 27, 2024
1 parent b1a898e commit db785cf
Showing 1 changed file with 82 additions and 3 deletions.
85 changes: 82 additions & 3 deletions lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmmPDL.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

Rewrite GetUser(op:Op<>)->Op;

Constraint ValidateOp(op:Op<>,input0:Op<>, input1:Op<>, input2:Op<>);

Rewrite BuildOpWithBetaZero(op:Op<vector.contract>,input0:Op<>, input1:Op<>, input2:Op<>, betaZero:Op<>)->(dispatch:Op<func.callOp>, invoke:Op<func.callOp>);

Rewrite BuildOp(op:Op<vector.contract>,input0:Op<>, input1:Op<>, input2:Op<>)->(dispatch:Op<func.callOp>, invoke:Op<func.callOp>);
Expand All @@ -13,7 +15,19 @@ Rewrite BuildOpWithBetaZeroAndBiasRelu(op:Op<vector.contract>,input0:Op<>, input

Rewrite BuildOpWithBiasRelu(op:Op<vector.contract>,input0:Op<>, input1:Op<>, input2:Op<>, addf:Op<vector.transfer_write>, maxf:Op<vector.transfer_write>)->(dispatch:Op<func.callOp>, invoke:Op<func.callOp>);

Rewrite BuildTranspose(op:Op<vector.transpose>, input0:Op<>, output: Type)->(dispatch:Op<func.callOp>, invoke:Op<func.callOp>);
Rewrite BuildTranspose(op:Op<>, input0:Op<>, output: Type)->(dispatch:Op<func.callOp>, invoke:Op<func.callOp>);

Constraint ValidateTranspose(op:Op<>, input0:Op<>, output:Type);

Rewrite TileConfigInsertionWithBetaZeroFusedBrgemm(op:Op<vector.contract>,input0:Op<>, input1:Op<>, input2:Op<>, betaZero:Op<>, addf:Op<vector.transfer_write>, maxf:Op<vector.transfer_write>)->(alloca:Op<func.callOp>, amxTileConfigSetup:Op<func.callOp>, amxTileConfigReset:Op<func.callOp>);

Rewrite TileConfigInsertionWithFusedBrgemm(op:Op<vector.contract>,input0:Op<>, input1:Op<>, input2:Op<>, betaZero:Op<>, addf:Op<vector.transfer_write>, maxf:Op<vector.transfer_write>)->(alloca:Op<func.callOp>, amxTileConfigSetup:Op<func.callOp>, amxTileConfigReset:Op<func.callOp>);

Rewrite TileConfigInsertionWithBetaZeroBrgemm(op:Op<vector.contract>,input0:Op<>, input1:Op<>, input2:Op<>, betaZero:Op<>)->(alloca:Op<func.callOp>, amxTileConfigSetup:Op<func.callOp>, amxTileConfigReset:Op<func.callOp>);

Rewrite TileConfigInsertionWithBrgemm(op:Op<vector.contract>,input0:Op<>, input1:Op<>, input2:Op<>, betaZero:Op<>)->(alloca:Op<func.callOp>, amxTileConfigSetup:Op<func.callOp>, amxTileConfigReset:Op<func.callOp>);

Rewrite MoveInstructions(alloca:Op<func.callOp>, amxTileConfigSetup:Op<func.callOp>, amxTileConfigReset:Op<func.callOp>);

Pattern ConvertContractToFusedBrgemmWithBetaZero with benefit(15000),recursion{
let input2 = op<vector.transfer_read>(alloc2:Value, indices2:ValueRange, const0:Value, constIndex2:ValueRange)->(output2:TypeRange);
Expand All @@ -28,6 +42,8 @@ Pattern ConvertContractToFusedBrgemmWithBetaZero with benefit(15000),recursion{
let reluRead = op<vector.transfer_read>(alloc2, indices6:ValueRange, const0, bounds4:ValueRange)->(reluOutput:TypeRange);
let maxf = op<arith.maximumf>(reluRead, cst)->(typeRange);
let maxfResult = op<vector.transfer_write>(maxf, alloc2, indices7:ValueRange, bounds5:ValueRange)->(typeRange);
ValidateOp(root, input0, input1, input2);

rewrite root with{
let replacement = BuildOpWithBetaZeroAndBiasRelu(root, input0, input1, input2, betaZero, addfResult, maxfResult);
replace root with (replacement.dispatch, replacement.invoke);
Expand All @@ -51,6 +67,8 @@ Pattern ConvertContractToFusedBrgemm with benefit(10000),recursion{
let reluRead = op<vector.transfer_read>(alloc2, indices6:ValueRange, const0, bounds4:ValueRange)->(reluOutput:TypeRange);
let maxf = op<arith.maximumf>(reluRead, cst:Value)->(typeRange);
let maxfResult = op<vector.transfer_write>(maxf, alloc2, indices7:ValueRange, bounds5:ValueRange)->(typeRange);
ValidateOp(root, input0, input1, input2);

rewrite root with{
let replacement = BuildOpWithBiasRelu(root, input0, input1, input2, addfResult, maxfResult);
replace root with (replacement.dispatch, replacement.invoke);
Expand All @@ -66,8 +84,10 @@ Pattern ConvertContractToBrgemmWithBetaZero with benefit(5000),recursion{
let root = op<vector.contract>(input0:Op<>, input1:Op<>, input2)->(output:TypeRange);
let cst = op<arith.constant>()->(constantVector:AnyVector);
let betaZero = op<vector.transfer_write>(cst, alloc2, input3:ValueRange, input4:ValueRange);
ValidateOp(root, input0, input1, input2);

rewrite root with{
let replacement = BuildOpWithBetaZero(root, input0, input1, input2, betaZero);
let replacement = BuildOpWithBetaZero(root, input0, input1, input2, betaZero);
replace root with (replacement.dispatch, replacement.invoke);
let user = GetUser(replacement.dispatch);
erase user;
Expand All @@ -77,8 +97,9 @@ Pattern ConvertContractToBrgemmWithBetaZero with benefit(5000),recursion{

Pattern ConvertContractToBrgemm with benefit(0),recursion{
let root = op<vector.contract>(input0:Op<>, input1:Op<>, input2:Op<>)->(output:TypeRange);
ValidateOp(root, input0, input1, input2);
rewrite root with{
let replacement = BuildOp(root, input0, input1, input2);
let replacement = BuildOp(root, input0, input1, input2);
replace root with (replacement.dispatch, replacement.invoke);
let user = GetUser(replacement.dispatch);
erase user;
Expand All @@ -87,6 +108,7 @@ Pattern ConvertContractToBrgemm with benefit(0),recursion{

Pattern ConvertTranspose with benefit(0),recursion{
let transpose = op<vector.transpose>(input0:Op<>)->(transposeOutput0:Type);
ValidateTranspose(transpose, input0, transposeOutput0);
rewrite transpose with{
let replacement = BuildTranspose(transpose, input0, transposeOutput0);
replace transpose with (replacement.dispatch, replacement.invoke);
Expand All @@ -95,4 +117,61 @@ Pattern ConvertTranspose with benefit(0),recursion{
};
}

Pattern AMXTileConfigInsertionWithBetaZeroFusedBrgemm with benefit(15000){
let input2 = op<vector.transfer_read>(alloc2:Value, indices2:ValueRange, const0:Value, constIndex2:ValueRange)->(output2:TypeRange);
let root = op<vector.contract>(input0:Op<>, input1:Op<>, input2)->(output:TypeRange);
let cst = op<arith.constant>()->(constantVector:AnyVector);
let betaZero = op<vector.transfer_write>(cst, alloc2, input3:ValueRange, input4:ValueRange);
let biasRead = op<vector.transfer_read>(bias:Value, indices3:ValueRange, const0, bounds1:ValueRange)->(biasOutput:TypeRange);
let biasBcast = op<vector.broadcast>(biasRead);
let biasTRead = op<vector.transfer_read>(alloc2, indices4:ValueRange, const0, bounds2:ValueRange)->(typeRange:TypeRange);
let addf = op<arith.addf>(biasBcast, biasTRead)->(typeRange);
let addfResult = op<vector.transfer_write>(addf, alloc2, indices5:ValueRange, bounds3:ValueRange)->(typeRange);
let reluRead = op<vector.transfer_read>(alloc2, indices6:ValueRange, const0, bounds4:ValueRange)->(reluOutput:TypeRange);
let maxf = op<arith.maximumf>(reluRead, cst)->(typeRange);
let maxfResult = op<vector.transfer_write>(maxf, alloc2, indices7:ValueRange, bounds5:ValueRange)->(typeRange);
rewrite root with{
let replacement = TileConfigInsertionWithBetaZeroFusedBrgemm(root, input0, input1, input2, betaZero, addfResult, maxfResult);
replace root with (replacement.alloca, replacement.amxTileConfigSetup, replacement.amxTileConfigReset);
MoveInstructions(replacement.alloca, replacement.amxTileConfigSetup, replacement.amxTileConfigReset);
}
}

Pattern AMXTileConfigInsertionWithFusedBrgemm with benefit(10000){
let input2 = op<vector.transfer_read>(alloc2:Value, indices2:ValueRange, const0:Value, constIndex2:ValueRange)->(output2:TypeRange);
let root = op<vector.contract>(input0:Op<>, input1:Op<>, input2)->(output:TypeRange);
let biasRead = op<vector.transfer_read>(bias:Value, indices3:ValueRange, const0, bounds:ValueRange)->(biasOutput:TypeRange);
let biasBcast = op<vector.broadcast>(biasRead);
let biasTRead = op<vector.transfer_read>(alloc2, indices4:ValueRange, const0, bounds2:ValueRange)->(typeRange:TypeRange);
let addf = op<arith.addf>(biasBcast, biasTRead)->(typeRange);
let addfResult = op<vector.transfer_write>(addf, alloc2, indices5:ValueRange, bounds3:ValueRange)->(typeRange);
let reluRead = op<vector.transfer_read>(alloc2, indices6:ValueRange, const0, bounds4:ValueRange)->(reluOutput:TypeRange);
let maxf = op<arith.maximumf>(reluRead, cst:Value)->(typeRange);
let maxfResult = op<vector.transfer_write>(maxf, alloc2, indices7:ValueRange, bounds5:ValueRange)->(typeRange);
rewrite root with{
let replacement = TileConfigInsertionWithFusedBrgemm(root, input0, input1, input2, addfResult, maxfResult);
replace root with (replacement.alloca, replacement.amxTileConfigSetup, replacement.amxTileConfigReset);
MoveInstructions(replacement.alloca, replacement.amxTileConfigSetup, replacement.amxTileConfigReset);
}
}

Pattern AMXTileConfigInsertionWithBetaZeroBrgemm with benefit(5000){
let input2 = op<vector.transfer_read>(alloc2:Value, indices2:ValueRange, const2:Value, constIndex2:ValueRange)->(output2:TypeRange);
let root = op<vector.contract>(input0:Op<>, input1:Op<>, input2)->(output:TypeRange);
let cst = op<arith.constant>()->(constantVector:AnyVector);
let betaZero = op<vector.transfer_write>(cst, alloc2, input3:ValueRange, input4:ValueRange);
rewrite root with{
let replacement = TileConfigInsertionWithBetaZeroBrgemm(root, input0, input1, input2, betaZero);
replace root with (replacement.alloca, replacement.amxTileConfigSetup, replacement.amxTileConfigReset);
MoveInstructions(replacement.alloca, replacement.amxTileConfigSetup, replacement.amxTileConfigReset);
}
}

Pattern AMXTileConfigInsertionWithBrgemm with benefit(0){
let root = op<vector.contract>(input0:Op<>, input1:Op<>, input2:Op<>)->(output:TypeRange);
rewrite root with{
let replacement = TileConfigInsertionWithBrgemm(root, input0, input1, input2);
replace root with (replacement.alloca, replacement.amxTileConfigSetup, replacement.amxTileConfigReset);
MoveInstructions(replacement.alloca, replacement.amxTileConfigSetup, replacement.amxTileConfigReset);
}
}

0 comments on commit db785cf

Please sign in to comment.