diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 6e24b607cebcf9..477478a4651cee 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1133,6 +1133,10 @@ def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> { def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> { let summary = "floating point division remainder operation"; + let description = [{ + Returns the floating point division remainder. + The remainder has the same sign as the dividend (lhs operand). + }]; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index f4781fcb546a3f..0e7beffaf2f421 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1204,7 +1204,10 @@ OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) { return constFoldBinaryOp(adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { APFloat result(a); - (void)result.remainder(b); + // APFloat::mod() offers the remainder + // behavior we want, i.e. the result has + // the sign of LHS operand. + (void)result.mod(b); return result; }); } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 4fe7cfb689be83..a386a178b78995 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2467,7 +2467,7 @@ func.func @test_remsi_1(%arg : vector<4xi32>) -> (vector<4xi32>) { // ----- // CHECK-LABEL: @test_remf( -// CHECK: %[[res:.+]] = arith.constant -1.000000e+00 : f32 +// CHECK: %[[res:.+]] = arith.constant 1.000000e+00 : f32 // CHECK: return %[[res]] func.func @test_remf() -> (f32) { %v1 = arith.constant 3.0 : f32 @@ -2476,11 +2476,24 @@ func.func @test_remf() -> (f32) { return %0 : f32 } +// CHECK-LABEL: @test_remf2( +// CHECK: %[[respos:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[resneg:.+]] = arith.constant -1.000000e+00 : f32 +// CHECK: return %[[respos]], %[[resneg]] +func.func @test_remf2() -> (f32, f32) { + %v1 = arith.constant 3.0 : f32 + %v2 = arith.constant -2.0 : f32 + %v3 = arith.constant -3.0 : f32 + %0 = arith.remf %v1, %v2 : f32 + %1 = arith.remf %v3, %v2 : f32 + return %0, %1 : f32, f32 +} + // CHECK-LABEL: @test_remf_vec( // CHECK: %[[res:.+]] = arith.constant dense<[1.000000e+00, 0.000000e+00, -1.000000e+00, 0.000000e+00]> : vector<4xf32> // CHECK: return %[[res]] func.func @test_remf_vec() -> (vector<4xf32>) { - %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32> + %v1 = arith.constant dense<[1.0, 2.0, -3.0, 4.0]> : vector<4xf32> %v2 = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf32> %0 = arith.remf %v1, %v2 : vector<4xf32> return %0 : vector<4xf32>