Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "algebraic" fast-math intrinsics, based on fast-math ops that cannot return poison #120718

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1152,17 +1152,26 @@ fn codegen_regular_intrinsic_call<'tcx>(
ret.write_cvalue(fx, ret_val);
}

sym::fadd_fast | sym::fsub_fast | sym::fmul_fast | sym::fdiv_fast | sym::frem_fast => {
sym::fadd_fast
| sym::fsub_fast
| sym::fmul_fast
| sym::fdiv_fast
| sym::frem_fast
| sym::fadd_algebraic
| sym::fsub_algebraic
| sym::fmul_algebraic
| sym::fdiv_algebraic
| sym::frem_algebraic => {
intrinsic_args!(fx, args => (x, y); intrinsic);

let res = crate::num::codegen_float_binop(
fx,
match intrinsic {
sym::fadd_fast => BinOp::Add,
sym::fsub_fast => BinOp::Sub,
sym::fmul_fast => BinOp::Mul,
sym::fdiv_fast => BinOp::Div,
sym::frem_fast => BinOp::Rem,
sym::fadd_fast | sym::fadd_algebraic => BinOp::Add,
sym::fsub_fast | sym::fsub_algebraic => BinOp::Sub,
sym::fmul_fast | sym::fmul_algebraic => BinOp::Mul,
sym::fdiv_fast | sym::fdiv_algebraic => BinOp::Div,
sym::frem_fast | sym::frem_algebraic => BinOp::Rem,
_ => unreachable!(),
},
x,
Expand Down
25 changes: 25 additions & 0 deletions compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,31 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
self.frem(lhs, rhs)
}

fn fadd_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
// NOTE: it seems like we cannot enable fast-mode for a single operation in GCC.
lhs + rhs
}

fn fsub_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
// NOTE: it seems like we cannot enable fast-mode for a single operation in GCC.
lhs - rhs
}

fn fmul_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
// NOTE: it seems like we cannot enable fast-mode for a single operation in GCC.
lhs * rhs
}

fn fdiv_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
// NOTE: it seems like we cannot enable fast-mode for a single operation in GCC.
lhs / rhs
}

fn frem_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
// NOTE: it seems like we cannot enable fast-mode for a single operation in GCC.
self.frem(lhs, rhs)
}

fn checked_binop(&mut self, oop: OverflowOp, typ: Ty<'_>, lhs: Self::Value, rhs: Self::Value) -> (Self::Value, Self::Value) {
self.gcc_checked_binop(oop, typ, lhs, rhs)
}
Expand Down
48 changes: 44 additions & 4 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,46 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
}
}

fn fadd_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMBuildFAdd(self.llbuilder, lhs, rhs, UNNAMED);
llvm::LLVMRustSetAlgebraicMath(instr);
instr
}
}

fn fsub_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMBuildFSub(self.llbuilder, lhs, rhs, UNNAMED);
llvm::LLVMRustSetAlgebraicMath(instr);
instr
}
}

fn fmul_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMBuildFMul(self.llbuilder, lhs, rhs, UNNAMED);
llvm::LLVMRustSetAlgebraicMath(instr);
instr
}
}

fn fdiv_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMBuildFDiv(self.llbuilder, lhs, rhs, UNNAMED);
llvm::LLVMRustSetAlgebraicMath(instr);
instr
}
}

fn frem_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMBuildFRem(self.llbuilder, lhs, rhs, UNNAMED);
llvm::LLVMRustSetAlgebraicMath(instr);
instr
}
}

fn checked_binop(
&mut self,
oop: OverflowOp,
Expand Down Expand Up @@ -1327,17 +1367,17 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
pub fn vector_reduce_fmul(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
unsafe { llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src) }
}
pub fn vector_reduce_fadd_fast(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
pub fn vector_reduce_fadd_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMRustBuildVectorReduceFAdd(self.llbuilder, acc, src);
llvm::LLVMRustSetFastMath(instr);
llvm::LLVMRustSetAlgebraicMath(instr);
instr
}
}
pub fn vector_reduce_fmul_fast(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
pub fn vector_reduce_fmul_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src);
llvm::LLVMRustSetFastMath(instr);
llvm::LLVMRustSetAlgebraicMath(instr);
instr
}
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1880,14 +1880,14 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
arith_red!(simd_reduce_mul_ordered: vector_reduce_mul, vector_reduce_fmul, true, mul, 1.0);
arith_red!(
simd_reduce_add_unordered: vector_reduce_add,
vector_reduce_fadd_fast,
vector_reduce_fadd_algebraic,
false,
add,
0.0
);
arith_red!(
simd_reduce_mul_unordered: vector_reduce_mul,
vector_reduce_fmul_fast,
vector_reduce_fmul_algebraic,
false,
mul,
1.0
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1618,6 +1618,7 @@ extern "C" {
) -> &'a Value;

pub fn LLVMRustSetFastMath(Instr: &Value);
pub fn LLVMRustSetAlgebraicMath(Instr: &Value);

// Miscellaneous instructions
pub fn LLVMRustGetInstrProfIncrementIntrinsic(M: &Module) -> &Value;
Expand Down
32 changes: 32 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,38 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
}
}
}
sym::fadd_algebraic
| sym::fsub_algebraic
| sym::fmul_algebraic
| sym::fdiv_algebraic
| sym::frem_algebraic => match float_type_width(arg_tys[0]) {
bjorn3 marked this conversation as resolved.
Show resolved Hide resolved
Some(_width) => match name {
sym::fadd_algebraic => {
bx.fadd_algebraic(args[0].immediate(), args[1].immediate())
}
sym::fsub_algebraic => {
bx.fsub_algebraic(args[0].immediate(), args[1].immediate())
}
sym::fmul_algebraic => {
bx.fmul_algebraic(args[0].immediate(), args[1].immediate())
}
sym::fdiv_algebraic => {
bx.fdiv_algebraic(args[0].immediate(), args[1].immediate())
}
sym::frem_algebraic => {
bx.frem_algebraic(args[0].immediate(), args[1].immediate())
}
_ => bug!(),
},
None => {
bx.tcx().dcx().emit_err(InvalidMonomorphization::BasicFloatType {
span,
name,
ty: arg_tys[0],
});
return Ok(());
}
},

sym::float_to_int_unchecked => {
if float_type_width(arg_tys[0]).is_none() {
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_codegen_ssa/src/traits/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,27 @@ pub trait BuilderMethods<'a, 'tcx>:
fn add(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fadd(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fadd_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fadd_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn sub(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fsub(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fsub_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fsub_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn mul(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fmul(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fmul_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fmul_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn udiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn exactudiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn sdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn exactsdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fdiv_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn fdiv_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn urem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn srem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn frem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn frem_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn frem_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn shl(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn lshr(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
fn ashr(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
Expand Down
12 changes: 11 additions & 1 deletion compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ pub fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -
| sym::variant_count
| sym::is_val_statically_known
| sym::ptr_mask
| sym::debug_assertions => hir::Unsafety::Normal,
| sym::debug_assertions
| sym::fadd_algebraic
| sym::fsub_algebraic
| sym::fmul_algebraic
| sym::fdiv_algebraic
| sym::frem_algebraic => hir::Unsafety::Normal,
_ => hir::Unsafety::Unsafe,
};

Expand Down Expand Up @@ -405,6 +410,11 @@ pub fn check_intrinsic_type(
sym::fadd_fast | sym::fsub_fast | sym::fmul_fast | sym::fdiv_fast | sym::frem_fast => {
(1, 0, vec![param(0), param(0)], param(0))
}
sym::fadd_algebraic
| sym::fsub_algebraic
| sym::fmul_algebraic
| sym::fdiv_algebraic
| sym::frem_algebraic => (1, 0, vec![param(0), param(0)], param(0)),
sym::float_to_int_unchecked => (2, 0, vec![param(0)], param(1)),

sym::assume => (0, 0, vec![tcx.types.bool], Ty::new_unit(tcx)),
Expand Down
25 changes: 24 additions & 1 deletion compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,11 @@ extern "C" LLVMAttributeRef LLVMRustCreateMemoryEffectsAttr(LLVMContextRef C,
}
}

// Enable a fast-math flag
// Enable all fast-math flags, including those which will cause floating-point operations
// to return poison for some well-defined inputs. This function can only be used to build
// unsafe Rust intrinsics. That unsafety does permit additional optimizations, but at the
// time of writing, their value is not well-understood relative to those enabled by
// LLVMRustSetAlgebraicMath.
//
// https://llvm.org/docs/LangRef.html#fast-math-flags
extern "C" void LLVMRustSetFastMath(LLVMValueRef V) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need the fast intrinsics?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to keep them for now and let people play with both variants. I hope that based on experience we can make an argument that the unsafety of the unsafe ones is not worth the optimizations that they unlock, but I have no data to back that up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Expand All @@ -427,6 +431,25 @@ extern "C" void LLVMRustSetFastMath(LLVMValueRef V) {
}
}

// Enable fast-math flags which permit algebraic transformations that are not allowed by
// IEEE floating point. For example:
// a + (b + c) = (a + b) + c
// and
// a / b = a * (1 / b)
// Note that this does NOT enable any flags which can cause a floating-point operation on
// well-defined inputs to return poison, and therefore this function can be used to build
// safe Rust intrinsics (such as fadd_algebraic).
//
// https://llvm.org/docs/LangRef.html#fast-math-flags
extern "C" void LLVMRustSetAlgebraicMath(LLVMValueRef V) {
if (auto I = dyn_cast<Instruction>(unwrap<Value>(V))) {
I->setHasAllowReassoc(true);
I->setHasAllowContract(true);
I->setHasAllowReciprocal(true);
I->setHasNoSignedZeros(true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about afn (Approximate functions)? It doesn't poison but isn't mentioned here.

Does the word "algebraic" have a specific meaning here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind, I see from other places that it does.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about afn (Approximate functions)? It doesn't poison but isn't mentioned here.

Sure, why not. I'll add it.

Does the word "algebraic" have a specific meaning here?

I didn't want to just defang the existing intrinsics because there are some optimizations which rely on assuming that NaN/Inf do not occur. So I needed to come up with a new name, and the way I think about these intrinsics is that unlike IEEE float operations, they permit the usual algebraic transformations. Things like a + (b + b) = (a + b) + c and a / b = a * (1 / b).

I don't think that the name is perfect, and I'd be happy to see someone suggest a better name, but @orlp seems perfectly happy calling them "algebraic".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they permit the usual algebraic transformations. Things like a + (b + b) = (a + b) + c and a / b = a * (1 / b)

That would be a great explanation to put in a comment somewhere :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about afn (Approximate functions)? It doesn't poison but isn't mentioned here.

Sure, why not. I'll add it.

Please don't. Replacing functions by their approximations isn't an algebraically justified optimization.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 And in any case, I don't think it would matter for these intrinsics.

}
}

extern "C" LLVMValueRef
LLVMRustBuildAtomicLoad(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Source,
const char *Name, LLVMAtomicOrdering Order) {
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,8 +764,10 @@ symbols! {
f64_nan,
fabsf32,
fabsf64,
fadd_algebraic,
fadd_fast,
fake_variadic,
fdiv_algebraic,
fdiv_fast,
feature,
fence,
Expand All @@ -785,6 +787,7 @@ symbols! {
fmaf32,
fmaf64,
fmt,
fmul_algebraic,
fmul_fast,
fn_align,
fn_delegation,
Expand All @@ -810,6 +813,7 @@ symbols! {
format_unsafe_arg,
freeze,
freg,
frem_algebraic,
frem_fast,
from,
from_desugaring,
Expand All @@ -823,6 +827,7 @@ symbols! {
from_usize,
from_yeet,
fs_create_dir,
fsub_algebraic,
fsub_fast,
fundamental,
future,
Expand Down
40 changes: 40 additions & 0 deletions library/core/src/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1882,6 +1882,46 @@ extern "rust-intrinsic" {
#[rustc_nounwind]
pub fn frem_fast<T: Copy>(a: T, b: T) -> T;

/// Float addition that allows optimizations based on algebraic rules.
///
/// This intrinsic does not have a stable counterpart.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which meaning of "stable" does this use?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only way to call this intrinsic is to use the core_intrinsics feature. We do not have a wrapper for these like the atomic intrinsics.

#[rustc_nounwind]
#[rustc_safe_intrinsic]
#[cfg(not(bootstrap))]
pub fn fadd_algebraic<T: Copy>(a: T, b: T) -> T;

/// Float subtraction that allows optimizations based on algebraic rules.
///
/// This intrinsic does not have a stable counterpart.
#[rustc_nounwind]
#[rustc_safe_intrinsic]
#[cfg(not(bootstrap))]
pub fn fsub_algebraic<T: Copy>(a: T, b: T) -> T;

/// Float multiplication that allows optimizations based on algebraic rules.
///
/// This intrinsic does not have a stable counterpart.
#[rustc_nounwind]
#[rustc_safe_intrinsic]
#[cfg(not(bootstrap))]
pub fn fmul_algebraic<T: Copy>(a: T, b: T) -> T;

/// Float division that allows optimizations based on algebraic rules.
///
/// This intrinsic does not have a stable counterpart.
#[rustc_nounwind]
#[rustc_safe_intrinsic]
#[cfg(not(bootstrap))]
pub fn fdiv_algebraic<T: Copy>(a: T, b: T) -> T;

/// Float remainder that allows optimizations based on algebraic rules.
///
/// This intrinsic does not have a stable counterpart.
#[rustc_nounwind]
#[rustc_safe_intrinsic]
#[cfg(not(bootstrap))]
pub fn frem_algebraic<T: Copy>(a: T, b: T) -> T;

/// Convert with LLVM’s fptoui/fptosi, which may return undef for values out of range
/// (<https://github.com/rust-lang/rust/issues/10184>)
///
Expand Down
22 changes: 22 additions & 0 deletions tests/codegen/simd/issue-120720-reduce-nan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// compile-flags: -C opt-level=3 -C target-cpu=cannonlake
// only-x86_64

// In a previous implementation, _mm512_reduce_add_pd did the reduction with all fast-math flags
// enabled, making it UB to reduce a vector containing a NaN.

#![crate_type = "lib"]
#![feature(stdarch_x86_avx512, avx512_target_feature)]
use std::arch::x86_64::*;

// CHECK-label: @demo(
#[no_mangle]
#[target_feature(enable = "avx512f")] // Function-level target feature mismatches inhibit inlining
pub unsafe fn demo() -> bool {
// CHECK: %0 = tail call reassoc nsz arcp contract double @llvm.vector.reduce.fadd.v8f64(
// CHECK: %_0.i = fcmp uno double %0, 0.000000e+00
// CHECK: ret i1 %_0.i
let res = unsafe {
_mm512_reduce_add_pd(_mm512_set_pd(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, f64::NAN))
};
res.is_nan()
}
Loading