Skip to content

Commit 423d8ec

Browse files
committed
Add autocasts for bf16 and bf16xN
1 parent 2d549e7 commit 423d8ec

File tree

5 files changed

+37
-6
lines changed

5 files changed

+37
-6
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
370370
}
371371

372372
match self.type_kind(llvm_ty) {
373+
TypeKind::BFloat => rust_ty == self.type_i16(),
374+
373375
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
374376
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
375377
// as, well, packed structs, so they won't match with those either)
@@ -387,11 +389,18 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
387389
},
388390
)
389391
}
390-
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
392+
TypeKind::Vector => {
391393
let element_count = self.vector_length(llvm_ty) as u64;
392-
let int_width = element_count.next_power_of_two().max(8);
394+
let llvm_element_ty = self.element_type(llvm_ty);
393395

394-
rust_ty == self.type_ix(int_width)
396+
if llvm_element_ty == self.type_bf16() {
397+
rust_ty == self.type_vector(self.type_i16(), element_count)
398+
} else if llvm_element_ty == self.type_i1() {
399+
let int_width = element_count.next_power_of_two().max(8);
400+
rust_ty == self.type_ix(int_width)
401+
} else {
402+
false
403+
}
395404
}
396405
_ => false,
397406
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1761,7 +1761,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17611761
self.zext_i1_vector_to_int(val, src_ty, dest_ty)
17621762
}
17631763
}
1764-
_ => unreachable!(),
1764+
_ => self.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
17651765
}
17661766
}
17671767

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,9 @@ unsafe extern "C" {
963963
pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type;
964964
pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type;
965965

966+
// Operations on non-IEEE real types
967+
pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type;
968+
966969
// Operations on function types
967970
pub(crate) fn LLVMFunctionType<'a>(
968971
ReturnType: &'a Type,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
174174
)
175175
}
176176
}
177+
178+
pub(crate) fn type_bf16(&self) -> &'ll Type {
179+
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
180+
}
177181
}
178182

179183
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
@@ -247,7 +251,7 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
247251

248252
fn float_width(&self, ty: &'ll Type) -> usize {
249253
match self.type_kind(ty) {
250-
TypeKind::Half => 16,
254+
TypeKind::Half | TypeKind::BFloat => 16,
251255
TypeKind::Float => 32,
252256
TypeKind::Double => 64,
253257
TypeKind::X86_FP80 => 80,

tests/codegen-llvm/inject-autocast.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#![feature(link_llvm_intrinsics, abi_unadjusted, repr_simd, simd_ffi, portable_simd, f16)]
55
#![crate_type = "lib"]
66

7-
use std::simd::i64x2;
7+
use std::simd::{f32x4, i16x8, i64x2};
88

99
#[repr(simd)]
1010
pub struct Tile([i8; 1024]);
@@ -36,6 +36,19 @@ pub unsafe fn struct_with_i1_vector_autocast(a: i64x2, b: i64x2) -> (u8, u8) {
3636
foo(a, b)
3737
}
3838

39+
// CHECK-LABEL: @bf16_vector_autocast
40+
#[no_mangle]
41+
pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
42+
extern "unadjusted" {
43+
#[link_name = "llvm.x86.vcvtneps2bf16128"]
44+
fn foo(a: f32x4) -> i16x8;
45+
}
46+
47+
// CHECK: [[A:%[0-9]+]] = call <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float> {{.*}})
48+
// CHECK: bitcast <8 x bfloat> [[A]] to <8 x i16>
49+
foo(a)
50+
}
51+
3952
// CHECK-LABEL: @struct_autocast
4053
#[no_mangle]
4154
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
@@ -77,6 +90,8 @@ pub unsafe fn i1_vector_autocast(a: f16x8) -> u8 {
7790

7891
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
7992

93+
// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)
94+
8095
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
8196

8297
// CHECK: declare <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half>, i32 immarg)

0 commit comments

Comments
 (0)