From f559a6b16e8a6efbad43e9d4daa3240231616c77 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Fri, 31 Mar 2023 16:34:25 +0300 Subject: [PATCH] Integrate the SPIR-T `qptr` experiment. --- CHANGELOG.md | 7 +- Cargo.lock | 18 +++- crates/rustc_codegen_spirv/Cargo.toml | 2 +- .../rustc_codegen_spirv/src/codegen_cx/mod.rs | 6 ++ crates/rustc_codegen_spirv/src/linker/mod.rs | 22 ++--- .../src/linker/spirt_passes/diagnostics.rs | 89 +++++++++++++++++-- .../src/linker/spirt_passes/mod.rs | 25 ++++++ 7 files changed, 143 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 902dc07286..bc26c5574a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added ⭐ -- [PR#1039](https://github.com/EmbarkStudios/rust-gpu/pull/1039) added new experimental `sample_with` to `Image` API to set additional image operands. +- [PR#1020](https://github.com/EmbarkStudios/rust-gpu/pull/1020) added SPIR-T `qptr` + support in the form of `--spirt-passes=qptr`, a way to turn off "Storage Class inference", + and reporting for SPIR-T diagnostics - to test `qptr` fully, you can use: + `RUSTGPU_CODEGEN_ARGS="--no-infer-storage-classes --spirt-passes=qptr"` + (see also [the SPIR-T `qptr` PR](https://github.com/EmbarkStudios/spirt/pull/24) for more details about the `qptr` experiment) +- [PR#1039](https://github.com/EmbarkStudios/rust-gpu/pull/1039) added new experimental `sample_with` to `Image` API to set additional image operands - [PR#1031](https://github.com/EmbarkStudios/rust-gpu/pull/1031) added `Components` generic parameter to `Image` type, allowing images to return lower dimensional vectors and even scalars from the sampling API ### Changed 🛠 diff --git a/Cargo.lock b/Cargo.lock index 93538bd4d2..245ca1759a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1106,6 +1106,12 @@ dependencies = [ "web-sys", ] +[[package]] +name = "internal-iterator" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a668ef46056a63366da9d74f48062da9ece1a27958f2f3704aa6f7421c4433f5" + [[package]] name = "io-lifetimes" version = "1.0.6" @@ -1265,6 +1271,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "longest-increasing-subsequence" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3bd0dd2cd90571056fdb71f6275fada10131182f84899f4b2a916e565d81d86" + [[package]] name = "malloc_buf" version = "0.0.6" @@ -2215,16 +2227,18 @@ dependencies = [ [[package]] name = "spirt" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06834ebbbbc6f86448fd5dc7ccbac80e36f52f8d66838683752e19d3cae9a459" +checksum = "e24fa996f12f3c667efbceaa99c222b8910a295a14d2c43c3880dfab2752def7" dependencies = [ "arrayvec", "bytemuck", "elsa", "indexmap", + "internal-iterator", "itertools", "lazy_static", + "longest-increasing-subsequence", "rustc-hash", "serde", "serde_json", diff --git a/crates/rustc_codegen_spirv/Cargo.toml b/crates/rustc_codegen_spirv/Cargo.toml index d24bd78143..953a09edb2 100644 --- a/crates/rustc_codegen_spirv/Cargo.toml +++ b/crates/rustc_codegen_spirv/Cargo.toml @@ -58,7 +58,7 @@ serde_json = "1.0" smallvec = { version = "1.6.1", features = ["union"] } spirv-tools = { version = "0.9", default-features = false } rustc_codegen_spirv-types.workspace = true -spirt = "0.1.0" +spirt = "0.2.0" lazy_static = "1.4.0" itertools = "0.10.5" diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index e900107ed6..36bcf41713 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -335,6 +335,11 @@ impl CodegenArgs { "no-early-report-zombies", "delays reporting zombies (to allow more legalization)", ); + opts.optflag( + "", + "no-infer-storage-classes", + "disables SPIR-V Storage Class inference", + ); opts.optflag("", "no-structurize", "disables CFG structurization"); opts.optflag( @@ -515,6 +520,7 @@ impl CodegenArgs { dce: !matches.opt_present("no-dce"), compact_ids: !matches.opt_present("no-compact-ids"), early_report_zombies: !matches.opt_present("no-early-report-zombies"), + infer_storage_classes: !matches.opt_present("no-infer-storage-classes"), structurize: !matches.opt_present("no-structurize"), spirt: !matches.opt_present("no-spirt"), spirt_passes: matches diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index afe740920e..998045dff1 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -39,6 +39,7 @@ pub struct Options { pub compact_ids: bool, pub dce: bool, pub early_report_zombies: bool, + pub infer_storage_classes: bool, pub structurize: bool, pub spirt: bool, pub spirt_passes: Vec, @@ -228,7 +229,7 @@ pub fn link( zombies::report_and_remove_zombies(sess, opts, &mut output)?; } - { + if opts.infer_storage_classes { // HACK(eddyb) this is not the best approach, but storage class inference // can still fail in entirely legitimate ways (i.e. mismatches in zombies). if !opts.early_report_zombies { @@ -408,6 +409,7 @@ pub fn link( } if !opts.spirt_passes.is_empty() { + // FIXME(eddyb) why does this focus on functions, it could just be module passes?? spirt_passes::run_func_passes( &mut module, &opts.spirt_passes, @@ -441,21 +443,13 @@ pub fn link( // FIXME(eddyb) don't allocate whole `String`s here. std::fs::write(&dump_spirt_file_path, pretty.to_string()).unwrap(); - std::fs::write(dump_spirt_file_path.with_extension("spirt.html"), { - let mut html = pretty + std::fs::write( + dump_spirt_file_path.with_extension("spirt.html"), + pretty .render_to_html() .with_dark_mode_support() - .to_html_doc(); - // HACK(eddyb) this should be in `spirt::pretty` itself, - // but its need didn't become obvious until more recently. - html += " - "; - html - }) + .to_html_doc(), + ) .unwrap(); } diff --git a/crates/rustc_codegen_spirv/src/linker/spirt_passes/diagnostics.rs b/crates/rustc_codegen_spirv/src/linker/spirt_passes/diagnostics.rs index 94e1370cb7..538057d506 100644 --- a/crates/rustc_codegen_spirv/src/linker/spirt_passes/diagnostics.rs +++ b/crates/rustc_codegen_spirv/src/linker/spirt_passes/diagnostics.rs @@ -1,13 +1,13 @@ use crate::decorations::{CustomDecoration, SpanRegenerator, SrcLocDecoration, ZombieDecoration}; use rustc_data_structures::fx::FxIndexSet; -use rustc_errors::{DiagnosticBuilder, ErrorGuaranteed}; +use rustc_errors::DiagnosticBuilder; use rustc_session::Session; use rustc_span::{Span, DUMMY_SP}; use smallvec::SmallVec; use spirt::visit::{InnerVisit, Visitor}; use spirt::{ - spv, Attr, AttrSet, AttrSetDef, Const, Context, DataInstDef, DataInstKind, ExportKey, Exportee, - Func, GlobalVar, Module, Type, + spv, Attr, AttrSet, AttrSetDef, Const, Context, DataInstDef, DataInstKind, Diag, DiagLevel, + ExportKey, Exportee, Func, GlobalVar, Module, Type, }; use std::marker::PhantomData; use std::{mem, str}; @@ -35,6 +35,7 @@ pub(crate) fn report_diagnostics( use_stack: SmallVec::new(), span_regen: SpanRegenerator::new_spirt(sess.source_map(), module), overall_result: Ok(()), + any_spirt_bugs: false, }; for (export_key, &exportee) in &module.exports { assert_eq!(reporter.use_stack.len(), 0); @@ -56,6 +57,28 @@ pub(crate) fn report_diagnostics( export_key.inner_visit_with(&mut reporter); exportee.inner_visit_with(&mut reporter); } + + if reporter.any_spirt_bugs { + let mut note = sess.struct_note_without_error("SPIR-T bugs were reported"); + match &linker_options.dump_spirt_passes { + Some(dump_dir) => { + note.help(format!( + "pretty-printed SPIR-T will be saved to `{}`, as `.spirt.html` files", + dump_dir.display() + )); + } + None => { + // FIXME(eddyb) maybe just always generate the files in a tmpdir? + note.help( + "re-run with `RUSTGPU_CODEGEN_ARGS=\"--dump-spirt-passes=$PWD\"` to \ + get pretty-printed SPIR-T (`.spirt.html`)", + ); + } + } + note.note("pretty-printed SPIR-T is preferred when reporting Rust-GPU issues") + .emit(); + } + reporter.overall_result } @@ -81,8 +104,7 @@ fn decode_spv_lit_str_with(imms: &[spv::Imm], f: impl FnOnce(&str) -> R) -> R let words = imms.iter().enumerate().map(|(i, &imm)| match (i, imm) { (0, spirt::spv::Imm::Short(k, w) | spirt::spv::Imm::LongStart(k, w)) | (1.., spirt::spv::Imm::LongCont(k, w)) => { - // FIXME(eddyb) use `assert_eq!` after updating to latest SPIR-T. - assert!(k == wk.LiteralString); + assert_eq!(k, wk.LiteralString); w } _ => unreachable!(), @@ -138,6 +160,7 @@ struct DiagnosticReporter<'a> { use_stack: SmallVec<[UseOrigin<'a>; 8]>, span_regen: SpanRegenerator<'a>, overall_result: crate::linker::Result<()>, + any_spirt_bugs: bool, } enum UseOrigin<'a> { @@ -198,7 +221,7 @@ impl UseOrigin<'_> { &self, cx: &Context, span_regen: &mut SpanRegenerator<'_>, - err: &mut DiagnosticBuilder<'_, ErrorGuaranteed>, + err: &mut DiagnosticBuilder<'_, impl rustc_errors::EmissionGuarantee>, ) { let wk = &super::SpvSpecWithExtras::get().well_known; @@ -231,8 +254,7 @@ impl UseOrigin<'_> { &ExportKey::LinkName(name) => format!("function export `{}`", &cx[name]), ExportKey::SpvEntryPoint { imms, .. } => match imms[..] { [em @ spv::Imm::Short(em_kind, _), ref name_imms @ ..] => { - // FIXME(eddyb) use `assert_eq!` after updating to latest SPIR-T. - assert!(em_kind == wk.ExecutionModel); + assert_eq!(em_kind, wk.ExecutionModel); let em = spv::print::operand_from_imms([em]).concat_to_plain_text(); decode_spv_lit_str_with(name_imms, |name| { format!( @@ -299,6 +321,57 @@ impl DiagnosticReporter<'_> { self.overall_result = Err(err.emit()); } } + + let diags = attrs_def.attrs.iter().flat_map(|attr| match attr { + Attr::Diagnostics(diags) => diags.0.iter(), + _ => [].iter(), + }); + for diag in diags { + let Diag { level, message } = diag; + + let prefix = match level { + DiagLevel::Bug(location) => { + let location = location.to_string(); + let location = match location.split_once("/src/") { + Some((_path_prefix, intra_src)) => intra_src, + None => &location, + }; + format!("SPIR-T BUG [{location}] ") + } + DiagLevel::Error | DiagLevel::Warning => "".to_string(), + }; + let (deps, msg) = spirt::print::Plan::for_root(self.cx, message) + .pretty_print_deps_and_root_separately(); + + let deps = deps.to_string(); + let suffix = if !deps.is_empty() { + format!("\n where\n {}", deps.replace('\n', "\n ")) + } else { + "".to_string() + }; + + let def_span = current_def + .and_then(|def| def.to_rustc_span(self.cx, &mut self.span_regen)) + .unwrap_or(DUMMY_SP); + + let msg = [prefix, msg.to_string(), suffix].concat(); + match level { + DiagLevel::Bug(_) | DiagLevel::Error => { + let mut err = self.sess.struct_span_err(def_span, msg); + for use_origin in use_stack_for_def.iter().rev() { + use_origin.note(self.cx, &mut self.span_regen, &mut err); + } + self.overall_result = Err(err.emit()); + } + DiagLevel::Warning => { + let mut warn = self.sess.struct_span_warn(def_span, msg); + for use_origin in use_stack_for_def.iter().rev() { + use_origin.note(self.cx, &mut self.span_regen, &mut warn); + } + } + } + self.any_spirt_bugs = matches!(level, DiagLevel::Bug(_)); + } } } diff --git a/crates/rustc_codegen_spirv/src/linker/spirt_passes/mod.rs b/crates/rustc_codegen_spirv/src/linker/spirt_passes/mod.rs index 1772234e3b..4c0d77c328 100644 --- a/crates/rustc_codegen_spirv/src/linker/spirt_passes/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/spirt_passes/mod.rs @@ -108,6 +108,7 @@ def_spv_spec_with_extra_well_known! { /// Run intra-function passes on all `Func` definitions in the `Module`. // // FIXME(eddyb) introduce a proper "pass manager". +// FIXME(eddyb) why does this focus on functions, it could just be module passes?? pub(super) fn run_func_passes

( module: &mut Module, passes: &[impl AsRef], @@ -137,6 +138,30 @@ pub(super) fn run_func_passes

( for name in passes { let name = name.as_ref(); + + // HACK(eddyb) not really a function pass. + if name == "qptr" { + let layout_config = &spirt::qptr::LayoutConfig { + abstract_bool_size_align: (1, 1), + logical_ptr_size_align: (4, 4), + ..spirt::qptr::LayoutConfig::VULKAN_SCALAR_LAYOUT + }; + + let profiler = before_pass("qptr::lower_from_spv_ptrs", module); + spirt::passes::qptr::lower_from_spv_ptrs(module, layout_config); + after_pass("qptr::lower_from_spv_ptrs", module, profiler); + + let profiler = before_pass("qptr::analyze_uses", module); + spirt::passes::qptr::analyze_uses(module, layout_config); + after_pass("qptr::analyze_uses", module, profiler); + + let profiler = before_pass("qptr::lift_to_spv_ptrs", module); + spirt::passes::qptr::lift_to_spv_ptrs(module, layout_config); + after_pass("qptr::lift_to_spv_ptrs", module, profiler); + + continue; + } + let (full_name, pass_fn): (_, fn(_, &mut _)) = match name { "reduce" => ("spirt_passes::reduce", reduce::reduce_in_func), "fuse_selects" => (