Skip to content

Commit

Permalink
Add TypeKind::Vector&ConstKind::Scalar for vector types&consts.
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyb committed Jan 31, 2024
1 parent f61e300 commit 44faddc
Show file tree
Hide file tree
Showing 10 changed files with 420 additions and 189 deletions.
36 changes: 34 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ pub mod passes {
pub mod qptr;
pub mod scalar;
pub mod spv;
pub mod vector;

use smallvec::SmallVec;
use std::borrow::Cow;
Expand Down Expand Up @@ -471,6 +472,13 @@ pub enum TypeKind {
#[from]
Scalar(scalar::Type),

/// Vector (small array of [`scalar`]s) type, with some limitations on the
/// supported component counts (but all standard ones should be included).
///
/// See also the [`vector`] module for more documentation and definitions.
#[from]
Vector(vector::Type),

/// "Quasi-pointer", an untyped pointer-like abstract scalar that can represent
/// both memory locations (in any address space) and other kinds of locations
/// (e.g. SPIR-V `OpVariable`s in non-memory "storage classes").
Expand Down Expand Up @@ -509,7 +517,7 @@ macro_rules! impl_intern_type_kind {
})+
}
}
impl_intern_type_kind!(TypeKind, scalar::Type);
impl_intern_type_kind!(TypeKind, scalar::Type, vector::Type);

// HACK(eddyb) this is like `Either<Type, Const>`, only used in `TypeKind::SpvInst`,
// and only because SPIR-V type definitions can references both types and consts.
Expand All @@ -527,6 +535,12 @@ impl Type {
_ => None,
}
}
pub fn as_vector(self, cx: &Context) -> Option<vector::Type> {
match cx[self].kind {
TypeKind::Vector(ty) => Some(ty),
_ => None,
}
}
}

/// Interned handle for a [`ConstDef`](crate::ConstDef) (a constant value).
Expand Down Expand Up @@ -562,6 +576,18 @@ pub enum ConstKind {
#[from]
Scalar(scalar::Const),

/// Vector (small array of [`scalar`]s) constant, which must have
/// a type of [`TypeKind::Vector`] (of the same [`vector::Type`]).
///
/// See also the [`vector`] module for more documentation and definitions.
//
// FIXME(eddyb) maybe document the 128-bit limitation inherited from `scalar::Const`?
// FIXME(eddyb) this technically makes the `vector::Type` redundant, could
// it get out of sync? (perhaps "forced canonicalization" could be used to
// enforce that interning simply doesn't allow such scenarios?).
#[from]
Vector(vector::Const),

PtrToGlobalVar(GlobalVar),

// HACK(eddyb) this is a fallback case that should become increasingly rare
Expand Down Expand Up @@ -592,7 +618,7 @@ macro_rules! impl_intern_const_kind {
})+
}
}
impl_intern_const_kind!(scalar::Const);
impl_intern_const_kind!(scalar::Const, vector::Const);

// HACK(eddyb) on `Const` instead of `ConstDef` for ergonomics reasons.
impl Const {
Expand All @@ -602,6 +628,12 @@ impl Const {
_ => None,
}
}
pub fn as_vector(self, cx: &Context) -> Option<&vector::Const> {
match &cx[self].kind {
ConstKind::Vector(ct) => Some(ct),
_ => None,
}
}
}

/// Declarations ([`GlobalVarDecl`], [`FuncDecl`]) can contain a full definition,
Expand Down
149 changes: 65 additions & 84 deletions src/print/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,6 @@ enum UseStyle {
impl<'a> Printer<'a> {
fn new(plan: &Plan<'a>) -> Self {
let cx = plan.cx;
let wk = &spv::spec::Spec::get().well_known;

// HACK(eddyb) move this elsewhere.
enum SmallSet<T, const N: usize> {
Expand Down Expand Up @@ -813,21 +812,18 @@ impl<'a> Printer<'a> {
CxInterned::Type(ty) => {
let ty_def = &cx[ty];

// FIXME(eddyb) remove the duplication between
// here and `TypeDef`'s `Print` impl.
let has_compact_print_or_is_leaf = match &ty_def.kind {
TypeKind::SpvInst { spv_inst, type_and_const_inputs } => {
spv_inst.opcode == wk.OpTypeVector
|| type_and_const_inputs.is_empty()
let is_leaf = match &ty_def.kind {
TypeKind::SpvInst { type_and_const_inputs, .. } => {
type_and_const_inputs.is_empty()
}

TypeKind::Scalar(_)
| TypeKind::Vector(_)
| TypeKind::QPtr
| TypeKind::SpvStringLiteralForExtInst => true,
};

ty_def.attrs == AttrSet::default()
&& has_compact_print_or_is_leaf
ty_def.attrs == AttrSet::default() && is_leaf
}
CxInterned::Const(ct) => {
let ct_def = &cx[ct];
Expand Down Expand Up @@ -2360,70 +2356,43 @@ impl Print for TypeDef {

let wk = &spv::spec::Spec::get().well_known;

// FIXME(eddyb) should this be done by lowering SPIR-V types to SPIR-T?
let kw = |kw| printer.declarative_keyword_style().apply(kw).into();
#[allow(irrefutable_let_patterns)]
let compact_def = if let &TypeKind::SpvInst {
spv_inst: spv::Inst { opcode, ref imms },
ref type_and_const_inputs,
} = kind
{
if opcode == wk.OpTypeVector {
let (elem_ty, elem_count) = match (&imms[..], &type_and_const_inputs[..]) {
(&[spv::Imm::Short(_, elem_count)], &[TypeOrConst::Type(elem_ty)]) => {
(elem_ty, elem_count)
}
_ => unreachable!(),
};

Some(pretty::Fragment::new([
elem_ty.print(printer),
"×".into(),
printer.numeric_literal_style().apply(format!("{elem_count}")).into(),
]))
} else {
None
// FIXME(eddyb) should this just be `fmt::Display` on `scalar::Type`?
let print_scalar = |ty: scalar::Type| {
let width = ty.bit_width();
match ty {
scalar::Type::Bool => "bool".into(),
scalar::Type::SInt(_) => format!("s{width}"),
scalar::Type::UInt(_) => format!("u{width}"),
scalar::Type::Float(_) => format!("f{width}"),
}
} else {
None
};

AttrsAndDef {
attrs: attrs.print(printer),
def_without_name: if let Some(def) = compact_def {
def
} else {
match kind {
TypeKind::Scalar(ty) => {
let width = ty.bit_width();
kw(match ty {
scalar::Type::Bool => "bool".into(),
scalar::Type::SInt(_) => format!("s{width}"),
scalar::Type::UInt(_) => format!("u{width}"),
scalar::Type::Float(_) => format!("f{width}"),
})
}

// FIXME(eddyb) should this be shortened to `qtr`?
TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(),

TypeKind::SpvInst { spv_inst, type_and_const_inputs } => printer
.pretty_spv_inst(
printer.spv_op_style(),
spv_inst.opcode,
&spv_inst.imms,
type_and_const_inputs.iter().map(|&ty_or_ct| match ty_or_ct {
TypeOrConst::Type(ty) => ty.print(printer),
TypeOrConst::Const(ct) => ct.print(printer),
}),
),
TypeKind::SpvStringLiteralForExtInst => pretty::Fragment::new([
printer.error_style().apply("type_of").into(),
"(".into(),
printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString),
")".into(),
]),
}
def_without_name: match kind {
&TypeKind::Scalar(ty) => kw(print_scalar(ty)),
&TypeKind::Vector(ty) => kw(format!("{}×{}", print_scalar(ty.elem), ty.elem_count)),

// FIXME(eddyb) should this be shortened to `qtr`?
TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(),

TypeKind::SpvInst { spv_inst, type_and_const_inputs } => printer.pretty_spv_inst(
printer.spv_op_style(),
spv_inst.opcode,
&spv_inst.imms,
type_and_const_inputs.iter().map(|&ty_or_ct| match ty_or_ct {
TypeOrConst::Type(ty) => ty.print(printer),
TypeOrConst::Const(ct) => ct.print(printer),
}),
),
TypeKind::SpvStringLiteralForExtInst => pretty::Fragment::new([
printer.error_style().apply("type_of").into(),
"(".into(),
printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString),
")".into(),
]),
},
}
}
Expand All @@ -2438,14 +2407,11 @@ impl Print for ConstDef {

let kw = |kw| printer.declarative_keyword_style().apply(kw).into();

let def_without_name = match kind {
ConstKind::Undef => pretty::Fragment::new([
printer.imperative_keyword_style().apply("undef").into(),
printer.pretty_type_ascription_suffix(*ty),
]),
ConstKind::Scalar(scalar::Const::FALSE) => kw("false"),
ConstKind::Scalar(scalar::Const::TRUE) => kw("true"),
ConstKind::Scalar(ct) => {
// FIXME(eddyb) should this just a method on `scalar::Const` instead?
let print_scalar = |ct: scalar::Const, include_type_suffix: bool| match ct {
scalar::Const::FALSE => kw("false"),
scalar::Const::TRUE => kw("true"),
_ => {
let ty = ct.ty();
let width = ty.bit_width();
let (maybe_printed_value, ty_prefix) = match ty {
Expand Down Expand Up @@ -2492,17 +2458,19 @@ impl Print for ConstDef {
};
match maybe_printed_value {
Some(printed_value) => {
let literal_ty_suffix = pretty::Styles {
// HACK(eddyb) the exact type detracts from the value.
color_opacity: Some(0.4),
subscript: true,
..printer.declarative_keyword_style()
let printed_value = printer.numeric_literal_style().apply(printed_value);
if include_type_suffix {
let literal_ty_suffix = pretty::Styles {
// HACK(eddyb) the exact type detracts from the value.
color_opacity: Some(0.4),
subscript: true,
..printer.declarative_keyword_style()
}
.apply(format!("{ty_prefix}{width}"));
pretty::Fragment::new([printed_value, literal_ty_suffix])
} else {
printed_value.into()
}
.apply(format!("{ty_prefix}{width}"));
pretty::Fragment::new([
printer.numeric_literal_style().apply(printed_value),
literal_ty_suffix,
])
}
// HACK(eddyb) fallback using the bitwise representation.
None => pretty::Fragment::new([
Expand All @@ -2523,6 +2491,18 @@ impl Print for ConstDef {
]),
}
}
};

let def_without_name = match kind {
ConstKind::Undef => pretty::Fragment::new([
printer.imperative_keyword_style().apply("undef").into(),
printer.pretty_type_ascription_suffix(*ty),
]),
&ConstKind::Scalar(ct) => print_scalar(ct, true),
ConstKind::Vector(ct) => pretty::Fragment::new([
ty.print(printer),
pretty::join_comma_sep("(", ct.elems().map(|elem| print_scalar(elem, false)), ")"),
]),
&ConstKind::PtrToGlobalVar(gv) => {
pretty::Fragment::new(["&".into(), gv.print(printer)])
}
Expand Down Expand Up @@ -3251,6 +3231,7 @@ impl Print for FuncAt<'_, DataInst> {
if let Value::Const(ct) = v {
match &printer.cx[ct].kind {
ConstKind::Undef
| ConstKind::Vector(_)
| ConstKind::PtrToGlobalVar(_)
| ConstKind::SpvInst { .. } => {}

Expand Down
31 changes: 19 additions & 12 deletions src/qptr/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,21 @@ impl<'a> LayoutCache<'a> {
}
TypeKind::Scalar(ty) => return Ok(scalar(ty.bit_width())),

TypeKind::Vector(ty) => {
let len = u32::from(ty.elem_count.get());
return array(
cx.intern(ty.elem),
ArrayParams {
fixed_len: Some(len),
known_stride: None,

// NOTE(eddyb) this is specifically Vulkan "base alignment".
min_legacy_align: 1,
legacy_align_multiplier: if len <= 2 { 2 } else { 4 },
},
);
}

// FIXME(eddyb) treat `QPtr`s as scalars.
TypeKind::QPtr => {
return Err(LayoutError(Diag::bug(
Expand All @@ -359,26 +374,18 @@ impl<'a> LayoutCache<'a> {
// FIXME(eddyb) categorize `OpTypePointer` by storage class and split on
// logical vs physical here.
scalar_with_size_and_align(self.config.logical_ptr_size_align)
} else if [wk.OpTypeVector, wk.OpTypeMatrix].contains(&spv_inst.opcode) {
let len = short_imm_at(0);
let (min_legacy_align, legacy_align_multiplier) = if spv_inst.opcode == wk.OpTypeVector
{
// NOTE(eddyb) this is specifically Vulkan "base alignment".
(1, if len <= 2 { 2 } else { 4 })
} else {
(self.config.min_aggregate_legacy_align, 1)
};
} else if spv_inst.opcode == wk.OpTypeMatrix {
// NOTE(eddyb) `RowMajor` is disallowed on `OpTypeStruct` members below.
array(
match type_and_const_inputs[..] {
[TypeOrConst::Type(elem_type)] => elem_type,
_ => unreachable!(),
},
ArrayParams {
fixed_len: Some(len),
fixed_len: Some(short_imm_at(0)),
known_stride: None,
min_legacy_align,
legacy_align_multiplier,
min_legacy_align: self.config.min_aggregate_legacy_align,
legacy_align_multiplier: 1,
},
)?
} else if [wk.OpTypeArray, wk.OpTypeRuntimeArray].contains(&spv_inst.opcode) {
Expand Down
Loading

0 comments on commit 44faddc

Please sign in to comment.