Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer Storage Classes by "specializing" SPIR-V modules "generic" over them. #414

Merged
merged 16 commits into from
Feb 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/rustc_codegen_spirv/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ repository = "https://github.com/EmbarkStudios/rust-gpu"

[lib]
crate-type = ["dylib"]
test = false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this do? there are a few tests in rustc_codegen_spirv/linker/test.rs, does this disable those?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I... this is an accident, sorry! I was trying to debug reducing my RLS waiting time (it didn't work sigh) and accidentally included it in a commit apparently (wouldn't have happened if I was still using git gui oh well).


[features]
# By default, the use-compiled-tools is enabled, as doesn't require additional
Expand All @@ -33,6 +34,7 @@ rspirv = { git = "https://github.com/gfx-rs/rspirv.git", rev = "279cc519166b6a0b
rustc-demangle = "0.1.18"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
smallvec = "1.6.1"
spirv-tools = { version = "0.4.0", default-features = false }
tar = "0.4.30"
topological-sort = "0.1"
Expand Down
89 changes: 30 additions & 59 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,12 @@ use std::fmt;
/// tracking.
#[derive(Default)]
pub struct RecursivePointeeCache<'tcx> {
map: RefCell<HashMap<(PointeeTy<'tcx>, StorageClass), PointeeDefState>>,
map: RefCell<HashMap<PointeeTy<'tcx>, PointeeDefState>>,
}

impl<'tcx> RecursivePointeeCache<'tcx> {
fn begin(
&self,
cx: &CodegenCx<'tcx>,
span: Span,
pointee: PointeeTy<'tcx>,
storage_class: StorageClass,
) -> Option<Word> {
// Warning: storage_class must match the one called with end()
match self.map.borrow_mut().entry((pointee, storage_class)) {
fn begin(&self, cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>) -> Option<Word> {
match self.map.borrow_mut().entry(pointee) {
// State: This is the first time we've seen this type. Record that we're beginning to translate this type,
// and start doing the translation.
Entry::Vacant(entry) => {
Expand All @@ -52,7 +45,11 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
// emit an OpTypeForwardPointer, and use that ID. (This is the juicy part of this algorithm)
PointeeDefState::Defining => {
let new_id = cx.emit_global().id();
cx.emit_global().type_forward_pointer(new_id, storage_class);
// NOTE(eddyb) we emit `StorageClass::Generic` here, but later
// the linker will specialize the entire SPIR-V module to use
// storage classes inferred from `OpVariable`s.
cx.emit_global()
.type_forward_pointer(new_id, StorageClass::Generic);
entry.insert(PointeeDefState::DefiningWithForward(new_id));
if !cx.builder.has_capability(Capability::Addresses)
&& !cx
Expand Down Expand Up @@ -81,19 +78,16 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
cx: &CodegenCx<'tcx>,
span: Span,
pointee: PointeeTy<'tcx>,
storage_class: StorageClass,
pointee_spv: Word,
) -> Word {
// Warning: storage_class must match the one called with begin()
match self.map.borrow_mut().entry((pointee, storage_class)) {
match self.map.borrow_mut().entry(pointee) {
// We should have hit begin() on this type already, which always inserts an entry.
Entry::Vacant(_) => bug!("RecursivePointeeCache::end should always have entry"),
Entry::Occupied(mut entry) => match *entry.get() {
// State: There have been no recursive references to this type while defining it, and so no
// OpTypeForwardPointer has been emitted. This is the most common case.
PointeeDefState::Defining => {
let id = SpirvType::Pointer {
storage_class,
pointee: pointee_spv,
}
.def(span, cx);
Expand All @@ -105,7 +99,6 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
PointeeDefState::DefiningWithForward(id) => {
entry.insert(PointeeDefState::Defined(id));
SpirvType::Pointer {
storage_class,
pointee: pointee_spv,
}
.def_with_id(cx, span, id)
Expand Down Expand Up @@ -261,11 +254,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Cast(cast_target) => cast_target.spirv_type(span, cx),
PassMode::Indirect { .. } => {
let pointee = self.ret.layout.spirv_type(span, cx);
let pointer = SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee,
}
.def(span, cx);
let pointer = SpirvType::Pointer { pointee }.def(span, cx);
// Important: the return pointer comes *first*, not last.
argument_types.push(pointer);
SpirvType::Void.def(span, cx)
Expand Down Expand Up @@ -304,11 +293,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
extra_attrs: None, ..
} => {
let pointee = arg.layout.spirv_type(span, cx);
SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee,
}
.def(span, cx)
SpirvType::Pointer { pointee }.def(span, cx)
}
};
argument_types.push(arg_type);
Expand Down Expand Up @@ -465,26 +450,20 @@ fn trans_scalar<'tcx>(
Primitive::F32 => SpirvType::Float(32).def(span, cx),
Primitive::F64 => SpirvType::Float(64).def(span, cx),
Primitive::Pointer => {
let (storage_class, pointee_ty) = dig_scalar_pointee(cx, ty, index);
// Default to function storage class.
let storage_class = storage_class.unwrap_or(StorageClass::Function);
let pointee_ty = dig_scalar_pointee(cx, ty, index);
// Pointers can be recursive. So, record what we're currently translating, and if we're already translating
// the same type, emit an OpTypeForwardPointer and use that ID.
if let Some(predefined_result) =
cx.type_cache
.recursive_pointee_cache
.begin(cx, span, pointee_ty, storage_class)
if let Some(predefined_result) = cx
.type_cache
.recursive_pointee_cache
.begin(cx, span, pointee_ty)
{
predefined_result
} else {
let pointee = pointee_ty.spirv_type(span, cx);
cx.type_cache.recursive_pointee_cache.end(
cx,
span,
pointee_ty,
storage_class,
pointee,
)
cx.type_cache
.recursive_pointee_cache
.end(cx, span, pointee_ty, pointee)
}
}
}
Expand All @@ -504,12 +483,12 @@ fn dig_scalar_pointee<'tcx>(
cx: &CodegenCx<'tcx>,
ty: TyAndLayout<'tcx>,
index: Option<usize>,
) -> (Option<StorageClass>, PointeeTy<'tcx>) {
) -> PointeeTy<'tcx> {
match *ty.ty.kind() {
TyKind::Ref(_, elem_ty, _) | TyKind::RawPtr(TypeAndMut { ty: elem_ty, .. }) => {
let elem = cx.layout_of(elem_ty);
match index {
None => (None, PointeeTy::Ty(elem)),
None => PointeeTy::Ty(elem),
Some(index) => {
if elem.is_unsized() {
dig_scalar_pointee(cx, ty.field(cx, index), None)
Expand All @@ -518,12 +497,12 @@ fn dig_scalar_pointee<'tcx>(
// of ScalarPair could be deduced, but it's actually e.g. a sized pointer followed by some other
// completely unrelated type, not a wide pointer. So, translate this as a single scalar, one
// component of that ScalarPair.
(None, PointeeTy::Ty(elem))
PointeeTy::Ty(elem)
}
}
}
}
TyKind::FnPtr(sig) if index.is_none() => (None, PointeeTy::Fn(sig)),
TyKind::FnPtr(sig) if index.is_none() => PointeeTy::Fn(sig),
TyKind::Adt(def, _) if def.is_box() => {
let ptr_ty = cx.layout_of(cx.tcx.mk_mut_ptr(ty.ty.boxed_ty()));
dig_scalar_pointee(cx, ptr_ty, index)
Expand All @@ -542,11 +521,8 @@ fn dig_scalar_pointee_adt<'tcx>(
cx: &CodegenCx<'tcx>,
ty: TyAndLayout<'tcx>,
index: Option<usize>,
) -> (Option<StorageClass>, PointeeTy<'tcx>) {
// Storage classes can only be applied on structs containing a single pointer field (because we said so), so we only
// need to handle the attribute here.
let storage_class = get_storage_class(cx, ty);
let result = match &ty.variants {
) -> PointeeTy<'tcx> {
match &ty.variants {
// If it's a Variants::Multiple, then we want to emit the type of the dataful variant, not the type of the
// discriminant. This is because the discriminant can e.g. have type *mut(), whereas we want the full underlying
// type, only available in the dataful variant.
Expand Down Expand Up @@ -603,21 +579,16 @@ fn dig_scalar_pointee_adt<'tcx>(
},
}
}
};
match (storage_class, result) {
(storage_class, (None, result)) => (storage_class, result),
(None, (storage_class, result)) => (storage_class, result),
(Some(one), (Some(two), _)) => cx.tcx.sess.fatal(&format!(
"Double-applied storage class ({:?} and {:?}) on type {}",
one, two, ty.ty
)),
}
}

/// Handles `#[spirv(storage_class="blah")]`. Note this is only called in the scalar translation code, because this is only
/// Handles `#[spirv(storage_class="blah")]`. Note this is only called in the entry interface variables code, because this is only
/// used for spooky builtin stuff, and we pinky promise to never have more than one pointer field in one of these.
// TODO: Enforce this is only used in spirv-std.
fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Option<StorageClass> {
pub(crate) fn get_storage_class<'tcx>(
cx: &CodegenCx<'tcx>,
ty: TyAndLayout<'tcx>,
) -> Option<StorageClass> {
if let TyKind::Adt(adt, _substs) = ty.ty.kind() {
for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) {
if let SpirvAttribute::StorageClass(storage_class) = attr {
Expand Down
62 changes: 15 additions & 47 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -727,11 +727,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

fn alloca(&mut self, ty: Self::Type, _align: Align) -> Self::Value {
let ptr_ty = SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee: ty,
}
.def(self.span(), self);
let ptr_ty = SpirvType::Pointer { pointee: ty }.def(self.span(), self);
// "All OpVariable instructions in a function must be the first instructions in the first block."
let mut builder = self.emit();
builder.select_block(Some(0)).unwrap();
Expand Down Expand Up @@ -779,10 +775,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
return value;
}
let ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"load called on variable that wasn't a pointer: {:?}",
ty
Expand All @@ -803,10 +796,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {

fn atomic_load(&mut self, ptr: Self::Value, order: AtomicOrdering, _size: Size) -> Self::Value {
let ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"atomic_load called on variable that wasn't a pointer: {:?}",
ty
Expand Down Expand Up @@ -906,10 +896,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {

fn store(&mut self, val: Self::Value, ptr: Self::Value, _align: Align) -> Self::Value {
let ptr_elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"store called on variable that wasn't a pointer: {:?}",
ty
Expand Down Expand Up @@ -946,10 +933,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
_size: Size,
) {
let ptr_elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"atomic_store called on variable that wasn't a pointer: {:?}",
ty
Expand Down Expand Up @@ -979,15 +963,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

fn struct_gep(&mut self, ptr: Self::Value, idx: u64) -> Self::Value {
let (storage_class, result_pointee_type) = match self.lookup_type(ptr.ty) {
SpirvType::Pointer {
storage_class,
pointee,
} => match self.lookup_type(pointee) {
SpirvType::Adt { field_types, .. } => (storage_class, field_types[idx as usize]),
let result_pointee_type = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => match self.lookup_type(pointee) {
SpirvType::Adt { field_types, .. } => field_types[idx as usize],
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element, .. }
| SpirvType::Vector { element, .. } => (storage_class, element),
| SpirvType::Vector { element, .. } => element,
other => self.fatal(&format!(
"struct_gep not on struct, array, or vector type: {:?}, index {}",
other, idx
Expand All @@ -999,7 +980,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
)),
};
let result_type = SpirvType::Pointer {
storage_class,
pointee: result_pointee_type,
}
.def(self.span(), self);
Expand Down Expand Up @@ -1225,14 +1205,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {

fn pointercast(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
let val_pointee = match self.lookup_type(val.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(&format!(
"pointercast called on non-pointer source type: {:?}",
other
)),
};
let dest_pointee = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(&format!(
"pointercast called on non-pointer dest type: {:?}",
other
Expand All @@ -1249,12 +1229,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.access_chain(dest_ty, None, val.def(self), indices)
.unwrap()
.with_type(dest_ty)
} else if self
.really_unsafe_ignore_bitcasts
.borrow()
.contains(&self.current_fn)
{
val
} else {
let result = self
.emit()
Expand Down Expand Up @@ -1531,7 +1505,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
return;
}
let src_pointee = match self.lookup_type(src.ty) {
SpirvType::Pointer { pointee, .. } => Some(pointee),
SpirvType::Pointer { pointee } => Some(pointee),
_ => None,
};
let src_element_size = src_pointee.and_then(|p| self.lookup_type(p).sizeof(self));
Expand Down Expand Up @@ -1592,7 +1566,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
));
}
let elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Pointer { pointee } => pointee,
_ => self.fatal(&format!(
"memset called on non-pointer type: {}",
self.debug_type(ptr.ty)
Expand Down Expand Up @@ -1780,10 +1754,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
_weak: bool,
) -> Self::Value {
let dst_pointee_ty = match self.lookup_type(dst.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"atomic_cmpxchg called on variable that wasn't a pointer: {:?}",
ty
Expand Down Expand Up @@ -1820,10 +1791,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
order: AtomicOrdering,
) -> Self::Value {
let dst_pointee_ty = match self.lookup_type(dst.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
SpirvType::Pointer { pointee } => pointee,
ty => self.fatal(&format!(
"atomic_rmw called on variable that wasn't a pointer: {:?}",
ty
Expand Down
Loading