Skip to content

Commit

Permalink
Add DataInstKind::Vector for pure vector ops.
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyb committed Jan 31, 2024
1 parent 51fde7a commit 47f7744
Show file tree
Hide file tree
Showing 11 changed files with 223 additions and 24 deletions.
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,12 @@ pub enum DataInstKind {
#[from]
Scalar(scalar::Op),

/// Vector (small array of [`scalar`]s) pure operations.
///
/// See also the [`vector`] module for more documentation and definitions.
#[from]
Vector(vector::Op),

// FIXME(eddyb) try to split this into recursive and non-recursive calls,
// to avoid needing special handling for recursion where it's impossible.
FuncCall(Func),
Expand Down
65 changes: 55 additions & 10 deletions src/print/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ use crate::print::multiversion::Versions;
use crate::qptr::{self, QPtrAttr, QPtrMemUsage, QPtrMemUsageKind, QPtrOp, QPtrUsage};
use crate::visit::{InnerVisit, Visit, Visitor};
use crate::{
cfg, scalar, spv, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context,
ControlNode, ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion,
cfg, scalar, spv, vector, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind,
Context, ControlNode, ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion,
ControlRegionDef, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstFormDef,
DataInstKind, DeclDef, Diag, DiagLevel, DiagMsgPart, EntityListIter, ExportKey, Exportee, Func,
FuncDecl, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDecl, GlobalVarDefBody,
Expand Down Expand Up @@ -2407,7 +2407,7 @@ impl Print for ConstDef {

let kw = |kw| printer.declarative_keyword_style().apply(kw).into();

// FIXME(eddyb) should this just a method on `scalar::Const` instead?
// FIXME(eddyb) should this be a method on `scalar::Const` instead?
let print_scalar = |ct: scalar::Const, include_type_suffix: bool| match ct {
scalar::Const::FALSE => kw("false"),
scalar::Const::TRUE => kw("true"),
Expand Down Expand Up @@ -3023,17 +3023,62 @@ impl Print for FuncAt<'_, DataInst> {

let mut output_type_to_print = *output_type;

// FIXME(eddyb) should this be a method on `scalar::Op` instead?
let print_scalar = |op: scalar::Op| {
let name = op.name();
let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1);
pretty::Fragment::new([
printer
.demote_style_for_namespace_prefix(printer.declarative_keyword_style())
.apply(namespace_prefix),
printer.declarative_keyword_style().apply(name),
])
};

let def_without_type = match kind {
&DataInstKind::Scalar(op) => {
let name = op.name();
&DataInstKind::Scalar(op) => pretty::Fragment::new([
print_scalar(op),
pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"),
]),

&DataInstKind::Vector(op) => {
let (name, extra_last_input) = match op {
vector::Op::Distribute(_) => ("vec.distribute", None),
vector::Op::Reduce(op) => (op.name(), None),
vector::Op::Whole(op) => (
op.name(),
match op {
vector::WholeOp::Extract { elem_idx }
| vector::WholeOp::Insert { elem_idx } => Some(
printer.numeric_literal_style().apply(elem_idx.to_string()).into(),
),
vector::WholeOp::New
| vector::WholeOp::DynExtract
| vector::WholeOp::DynInsert
| vector::WholeOp::Mul => None,
},
),
};
let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1);
pretty::Fragment::new([
let mut pretty_name = pretty::Fragment::new([
printer
.demote_style_for_namespace_prefix(printer.declarative_keyword_style())
.apply(namespace_prefix)
.into(),
printer.declarative_keyword_style().apply(name).into(),
pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"),
.apply(namespace_prefix),
printer.declarative_keyword_style().apply(name),
]);
if let vector::Op::Distribute(op) = op {
pretty_name = pretty::Fragment::new([
pretty_name,
pretty::join_comma_sep("(", [print_scalar(op)], ")"),
]);
}
pretty::Fragment::new([
pretty_name,
pretty::join_comma_sep(
"(",
inputs.iter().map(|v| v.print(printer)).chain(extra_last_input),
")",
),
])
}

Expand Down
2 changes: 1 addition & 1 deletion src/qptr/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ impl<'a> InferUsage<'a> {
});
};
match &data_inst_form_def.kind {
DataInstKind::Scalar(_) => {}
DataInstKind::Scalar(_) | DataInstKind::Vector(_) => {}

&DataInstKind::FuncCall(callee) => {
match self.infer_usage_in_func(module, callee) {
Expand Down
2 changes: 1 addition & 1 deletion src/qptr/lift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ impl LiftToSpvPtrInstsInFunc<'_> {
Ok((addr_space, self.lifter.layout_of(pointee_type)?))
};
let replacement_data_inst_def = match &data_inst_form_def.kind {
DataInstKind::Scalar(_) => return Ok(Transformed::Unchanged),
DataInstKind::Scalar(_) | DataInstKind::Vector(_) => return Ok(Transformed::Unchanged),

&DataInstKind::FuncCall(_callee) => {
for &v in &data_inst_def.inputs {
Expand Down
5 changes: 4 additions & 1 deletion src/qptr/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,10 @@ impl LowerFromSpvPtrInstsInFunc<'_> {

match data_inst_form_def.kind {
// Known semantics, no need to preserve SPIR-V pointer information.
DataInstKind::Scalar(_) | DataInstKind::FuncCall(_) | DataInstKind::QPtr(_) => return,
DataInstKind::Scalar(_)
| DataInstKind::Vector(_)
| DataInstKind::FuncCall(_)
| DataInstKind::QPtr(_) => return,

DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {}
}
Expand Down
94 changes: 85 additions & 9 deletions src/spv/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ macro_rules! def_mappable_ops {
(
type { $($ty_op:ident),+ $(,)? }
const { $($ct_op:ident),+ $(,)? }
data_inst { $($di_op:ident),+ $(,)? }
$($enum_path:path { $($variant_op:ident <=> $variant:ident$(($($variant_args:tt)*))?),+ $(,)? })*
) => {
#[allow(non_snake_case)]
struct MappableOps {
$($ty_op: spec::Opcode,)+
$($ct_op: spec::Opcode,)+
$($di_op: spec::Opcode,)+
$($($variant_op: spec::Opcode,)+)*
}
impl MappableOps {
Expand All @@ -35,6 +37,7 @@ macro_rules! def_mappable_ops {
MappableOps {
$($ty_op: spv_spec.instructions.lookup(stringify!($ty_op)).unwrap(),)+
$($ct_op: spv_spec.instructions.lookup(stringify!($ct_op)).unwrap(),)+
$($di_op: spv_spec.instructions.lookup(stringify!($di_op)).unwrap(),)+
$($($variant_op: spv_spec.instructions.lookup(stringify!($variant_op)).unwrap(),)+)*
}
};
Expand Down Expand Up @@ -74,6 +77,11 @@ def_mappable_ops! {
OpConstantTrue,
OpConstant,
}
data_inst {
OpVectorExtractDynamic,
OpVectorInsertDynamic,
OpVectorTimesScalar,
}
scalar::BoolUnOp {
OpLogicalNot <=> Not,
}
Expand Down Expand Up @@ -164,6 +172,11 @@ def_mappable_ops! {
OpFUnordLessThanEqual <=> CmpOrUnord(scalar::FloatCmp::Le),
OpFUnordGreaterThanEqual <=> CmpOrUnord(scalar::FloatCmp::Ge),
}
vector::ReduceOp {
OpDot <=> Dot,
OpAny <=> Any,
OpAll <=> All,
}
}

impl scalar::Const {
Expand Down Expand Up @@ -424,16 +437,46 @@ impl spv::Inst {
if let Some(op) = scalar_op {
assert_eq!(imms.len(), 0);

// FIXME(eddyb) support vector versions of these ops as well.
if output_types.len() == op.output_count()
&& output_types.iter().all(|ty| ty.as_scalar(cx).is_some())
{
Some(op.into())
let (_scalar_type, vec_elem_count) = (output_types.len() == op.output_count())
.then(|| {
output_types.iter().map(|&ty| match cx[ty].kind {
TypeKind::Scalar(ty) => Some((ty, None)),
TypeKind::Vector(ty) => Some((ty.elem, Some(ty.elem_count))),
_ => None,
})
})
.and_then(|mut outputs| {
let first = outputs.next().unwrap()?;
outputs.all(|x| x == Some(first)).then_some(first)
})?;

Some(if vec_elem_count.is_some() {
vector::Op::Distribute(op).into()
} else {
None
}
op.into()
})
} else if let Some(op) = vector::ReduceOp::try_from_opcode(opcode).map(vector::Op::from) {
assert_eq!(imms.len(), 0);
Some(op.into())
} else {
None
let wk = &spec::Spec::get().well_known;
let mo = MappableOps::get();

// FIXME(eddyb) automate this by supporting immediates in the macro.
let v_whole = |op| Some(vector::Op::Whole(op).into());
match imms {
[] if opcode == wk.OpCompositeConstruct => v_whole(vector::WholeOp::New),
&[spv::Imm::Short(_, elem_idx)] if opcode == wk.OpCompositeExtract => {
v_whole(vector::WholeOp::Extract { elem_idx: elem_idx.try_into().ok()? })
}
&[spv::Imm::Short(_, elem_idx)] if opcode == wk.OpCompositeInsert => {
v_whole(vector::WholeOp::Insert { elem_idx: elem_idx.try_into().ok()? })
}
[] if opcode == mo.OpVectorExtractDynamic => v_whole(vector::WholeOp::DynExtract),
[] if opcode == mo.OpVectorInsertDynamic => v_whole(vector::WholeOp::DynInsert),
[] if opcode == mo.OpVectorTimesScalar => v_whole(vector::WholeOp::Mul),
_ => None,
}
}
}

Expand All @@ -447,7 +490,40 @@ impl spv::Inst {
scalar::Op::FloatUnary(op) => op.to_opcode().into(),
scalar::Op::FloatBinary(op) => op.to_opcode().into(),
}),
_ => None,
&DataInstKind::Vector(op) => Some(match op {
vector::Op::Distribute(op) => {
Self::from_canonical_data_inst_kind(&DataInstKind::Scalar(op)).unwrap()
}
vector::Op::Reduce(op) => op.to_opcode().into(),
vector::Op::Whole(op) => {
let wk = &spec::Spec::get().well_known;
let mo = MappableOps::get();

// FIXME(eddyb) automate this by supporting immediates in the macro.
match op {
vector::WholeOp::New => wk.OpCompositeConstruct.into(),
vector::WholeOp::Extract { elem_idx } => spv::Inst {
opcode: wk.OpCompositeExtract,
imms: [spv::Imm::Short(wk.LiteralInteger, elem_idx.into())]
.into_iter()
.collect(),
},
vector::WholeOp::Insert { elem_idx } => spv::Inst {
opcode: wk.OpCompositeInsert,
imms: [spv::Imm::Short(wk.LiteralInteger, elem_idx.into())]
.into_iter()
.collect(),
},
vector::WholeOp::DynExtract => mo.OpVectorExtractDynamic.into(),
vector::WholeOp::DynInsert => mo.OpVectorInsertDynamic.into(),
vector::WholeOp::Mul => mo.OpVectorTimesScalar.into(),
}
}
}),
DataInstKind::FuncCall(_)
| DataInstKind::QPtr(_)
| DataInstKind::SpvInst(..)
| DataInstKind::SpvExtInst { .. } => None,
}
}
}
8 changes: 6 additions & 2 deletions src/spv/lift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,11 @@ impl Visitor<'_> for NeedsIdsCollector<'_> {
unreachable!("`DataInstKind::QPtr` should be legalized away before lifting");
}

DataInstKind::Scalar(_) | DataInstKind::FuncCall(_) | DataInstKind::SpvInst(_) => {}
DataInstKind::Scalar(_)
| DataInstKind::Vector(_)
| DataInstKind::FuncCall(_)
| DataInstKind::SpvInst(_) => {}

DataInstKind::SpvExtInst { ext_set, .. } => {
self.ext_inst_imports.insert(&self.cx[ext_set]);
}
Expand Down Expand Up @@ -1286,7 +1290,7 @@ impl LazyInst<'_, '_> {
match spv::Inst::from_canonical_data_inst_kind(kind).ok_or(kind) {
Ok(spv_inst) => (spv_inst, None),

Err(DataInstKind::Scalar(_)) => {
Err(DataInstKind::Scalar(_) | DataInstKind::Vector(_)) => {
unreachable!("should've been handled as canonical")
}

Expand Down
6 changes: 6 additions & 0 deletions src/spv/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def_well_known! {
OpTypeSampledImage,
OpTypeAccelerationStructureKHR,

// FIXME(eddyb) hide these from code, lowering should handle most cases.
OpConstantComposite,

OpVariable,
Expand Down Expand Up @@ -159,6 +160,11 @@ def_well_known! {
OpPtrAccessChain,
OpInBoundsPtrAccessChain,
OpBitcast,

// FIXME(eddyb) hide these from code, lowering should handle most cases.
OpCompositeInsert,
OpCompositeExtract,
OpCompositeConstruct,
],
operand_kind: OperandKind = [
Capability,
Expand Down
1 change: 1 addition & 0 deletions src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ impl InnerTransform for DataInstFormDef {
| QPtrOp::Store => Transformed::Unchanged,
},
DataInstKind::Scalar(_)
| DataInstKind::Vector(_)
| DataInstKind::SpvInst(_)
| DataInstKind::SpvExtInst { .. } => Transformed::Unchanged,
},
Expand Down
57 changes: 57 additions & 0 deletions src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,60 @@ impl Const {
(0..usize::from(ty.elem_count.get())).map(|i| self.get_elem(i).unwrap())
}
}

/// Pure operations with vector inputs and/or outputs.
#[derive(Copy, Clone, PartialEq, Eq, Hash, derive_more::From)]
pub enum Op {
Distribute(scalar::Op),
Reduce(ReduceOp),

// FIXME(eddyb) find a better name for this category of ops.
Whole(WholeOp),
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum ReduceOp {
// FIXME(eddyb) also support all the new integer dot product instructions.
Dot,
// FIXME(eddyb) model these using their respective `BoolBinOp`s?
Any,
All,
}

impl ReduceOp {
pub fn name(self) -> &'static str {
match self {
ReduceOp::Dot => "vec.dot",
ReduceOp::Any => "vec.any",
ReduceOp::All => "vec.all",
}
}
}

// FIXME(eddyb) find a better name for this category of ops.
// FIXME(eddyb) also support `OpVectorShuffle`.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum WholeOp {
// FIXME(eddyb) better name for this (pack? make? "construct" is too long).
New,
Extract { elem_idx: u8 },
Insert { elem_idx: u8 },
DynExtract,
DynInsert,

// FIXME(eddyb) may need a better name to indicate "scalar product".
Mul,
}

impl WholeOp {
pub fn name(self) -> &'static str {
match self {
WholeOp::New => "vec.new",
WholeOp::Extract { .. } => "vec.extract",
WholeOp::Insert { .. } => "vec.insert",
WholeOp::DynExtract => "vec.dyn_extract",
WholeOp::DynInsert => "vec.dyn_insert",
WholeOp::Mul => "vec.mul",
}
}
}
1 change: 1 addition & 0 deletions src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ impl InnerVisit for DataInstFormDef {
| QPtrOp::Store => {}
},
DataInstKind::Scalar(_)
| DataInstKind::Vector(_)
| DataInstKind::SpvInst(_)
| DataInstKind::SpvExtInst { .. } => {}
}
Expand Down

0 comments on commit 47f7744

Please sign in to comment.