Skip to content

Commit

Permalink
Add insert/extractBits and pack/unpack functions
Browse files Browse the repository at this point in the history
  • Loading branch information
cwfitzgerald committed Oct 5, 2021
1 parent da00bf2 commit 9ceae73
Show file tree
Hide file tree
Showing 16 changed files with 541 additions and 38 deletions.
4 changes: 4 additions & 0 deletions src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ fn write_fun(
arg,
arg1,
arg2,
arg3,
} => {
edges.insert("arg", arg);
if let Some(expr) = arg1 {
Expand All @@ -330,6 +331,9 @@ fn write_fun(
if let Some(expr) = arg2 {
edges.insert("arg2", expr);
}
if let Some(expr) = arg3 {
edges.insert("arg3", expr);
}
(format!("{:?}", fun).into(), 7)
}
E::As {
Expand Down
19 changes: 19 additions & 0 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2321,6 +2321,7 @@ impl<'a, W: Write> Writer<'a, W> {
arg,
arg1,
arg2,
arg3,
} => {
use crate::MathFunction as Mf;

Expand Down Expand Up @@ -2385,6 +2386,20 @@ impl<'a, W: Write> Writer<'a, W> {
// bits
Mf::CountOneBits => "bitCount",
Mf::ReverseBits => "bitfieldReverse",
Mf::ExtractBits => "bitfieldExtract",
Mf::InsertBits => "bitfieldInsert",
// data packing
Mf::Pack4x8snorm => "packSnorm4x8",
Mf::Pack4x8unorm => "packUnorm4x8",
Mf::Pack2x16snorm => "packSnorm2x16",
Mf::Pack2x16unorm => "packUnorm2x16",
Mf::Pack2x16float => "packHalf2x16",
// data unpacking
Mf::Unpack4x8snorm => "unpackSnorm4x8",
Mf::Unpack4x8unorm => "unpackUnorm4x8",
Mf::Unpack2x16snorm => "unpackSnorm2x16",
Mf::Unpack2x16unorm => "unpackUnorm2x16",
Mf::Unpack2x16float => "unpackHalf2x16",
};

write!(self.out, "{}(", fun_name)?;
Expand All @@ -2397,6 +2412,10 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, ", ")?;
self.write_expr(arg, ctx)?;
}
if let Some(arg) = arg3 {
write!(self.out, ", ")?;
self.write_expr(arg, ctx)?;
}
write!(self.out, ")")?
}
// `As` is always a call.
Expand Down
5 changes: 5 additions & 0 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1816,6 +1816,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
arg,
arg1,
arg2,
arg3,
} => {
use crate::MathFunction as Mf;

Expand Down Expand Up @@ -1918,6 +1919,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
}
if let Some(arg) = arg3 {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
}
write!(self.out, ")")?
}
}
Expand Down
32 changes: 29 additions & 3 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,7 @@ impl<W: Write> Writer<W> {
arg,
arg1,
arg2,
arg3,
} => {
use crate::MathFunction as Mf;

Expand Down Expand Up @@ -1178,6 +1179,20 @@ impl<W: Write> Writer<W> {
// bits
Mf::CountOneBits => "popcount",
Mf::ReverseBits => "reverse_bits",
Mf::ExtractBits => "extract_bits",
Mf::InsertBits => "insert_bits",
// data packing
Mf::Pack4x8snorm => "pack_float_to_unorm4x8",
Mf::Pack4x8unorm => "pack_float_to_snorm4x8",
Mf::Pack2x16snorm => "pack_float_to_unorm2x16",
Mf::Pack2x16unorm => "pack_float_to_snorm2x16",
Mf::Pack2x16float => "",
// data unpacking
Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
Mf::Unpack2x16float => "",
};

if fun == Mf::Distance && scalar_argument {
Expand All @@ -1186,9 +1201,20 @@ impl<W: Write> Writer<W> {
write!(self.out, " - ")?;
self.put_expression(arg1.unwrap(), context, false)?;
write!(self.out, ")")?;
} else if fun == Mf::Unpack2x16float {
write!(self.out, "float2(as_type<half2>(")?;
self.put_expression(arg, context, false)?;
write!(self.out, "))")?;
} else if fun == Mf::Pack2x16float {
write!(self.out, "as_type<uint>(half2(")?;
self.put_expression(arg, context, false)?;
write!(self.out, "))")?;
} else {
write!(self.out, "{}::{}", NAMESPACE, fun_name)?;
self.put_call_parameters(iter::once(arg).chain(arg1).chain(arg2), context)?;
self.put_call_parameters(
iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
context,
)?;
}
}
crate::Expression::As {
Expand Down Expand Up @@ -2661,8 +2687,8 @@ fn test_stack_size() {
}
let stack_size = addresses.end - addresses.start;
// check the size (in debug only)
// last observed macOS value: 18304
if !(15000..=20000).contains(&stack_size) {
// last observed macOS value: 20528 (CI)
if !(15000..=25000).contains(&stack_size) {
panic!("`put_expression` stack size {} has changed!", stack_size);
}
}
Expand Down
41 changes: 40 additions & 1 deletion src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ impl<'w> BlockContext<'w> {
arg,
arg1,
arg2,
arg3,
} => {
use crate::MathFunction as Mf;
enum MathOp {
Expand All @@ -457,6 +458,10 @@ impl<'w> BlockContext<'w> {
Some(handle) => self.cached[handle],
None => 0,
};
let arg3_id = match arg3 {
Some(handle) => self.cached[handle],
None => 0,
};

let id = self.gen_id();
let math_op = match fun {
Expand Down Expand Up @@ -606,6 +611,40 @@ impl<'w> BlockContext<'w> {
log::error!("unimplemented math function {:?}", fun);
return Err(Error::FeatureNotImplemented("math function"));
}
Mf::ExtractBits => {
let op = match arg_scalar_kind {
Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
other => unimplemented!("Unexpected sign({:?})", other),
};
MathOp::Custom(Instruction::ternary(
op,
result_type_id,
id,
arg0_id,
arg1_id,
arg2_id,
))
}
Mf::InsertBits => MathOp::Custom(Instruction::quaternary(
spirv::Op::BitFieldInsert,
result_type_id,
id,
arg0_id,
arg1_id,
arg2_id,
arg3_id,
)),
Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
};

block.body.push(match math_op {
Expand All @@ -614,7 +653,7 @@ impl<'w> BlockContext<'w> {
op,
result_type_id,
id,
&[arg0_id, arg1_id, arg2_id][..fun.argument_count()],
&[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
),
MathOp::Custom(inst) => inst,
});
Expand Down
36 changes: 36 additions & 0 deletions src/back/spv/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,42 @@ impl super::Instruction {
instruction
}

pub(super) fn ternary(
op: Op,
result_type_id: Word,
id: Word,
operand_1: Word,
operand_2: Word,
operand_3: Word,
) -> Self {
let mut instruction = Self::new(op);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(operand_1);
instruction.add_operand(operand_2);
instruction.add_operand(operand_3);
instruction
}

pub(super) fn quaternary(
op: Op,
result_type_id: Word,
id: Word,
operand_1: Word,
operand_2: Word,
operand_3: Word,
operand_4: Word,
) -> Self {
let mut instruction = Self::new(op);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(operand_1);
instruction.add_operand(operand_2);
instruction.add_operand(operand_3);
instruction.add_operand(operand_4);
instruction
}

pub(super) fn relational(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
let mut instruction = Self::new(op);
instruction.set_type(result_type_id);
Expand Down
19 changes: 19 additions & 0 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,7 @@ impl<W: Write> Writer<W> {
arg,
arg1,
arg2,
arg3,
} => {
use crate::MathFunction as Mf;

Expand Down Expand Up @@ -1523,6 +1524,20 @@ impl<W: Write> Writer<W> {
// bits
Mf::CountOneBits => Function::Regular("countOneBits"),
Mf::ReverseBits => Function::Regular("reverseBits"),
Mf::ExtractBits => Function::Regular("extractBits"),
Mf::InsertBits => Function::Regular("insertBits"),
// data packing
Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"),
Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"),
Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"),
Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"),
Mf::Pack2x16float => Function::Regular("pack2x16float"),
// data unpacking
Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"),
Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"),
Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"),
Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"),
Mf::Unpack2x16float => Function::Regular("unpack2x16float"),
_ => {
return Err(Error::UnsupportedMathFunction(fun));
}
Expand Down Expand Up @@ -1559,6 +1574,10 @@ impl<W: Write> Writer<W> {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
}
if let Some(arg) = arg3 {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
}
write!(self.out, ")")?
}
}
Expand Down
Loading

0 comments on commit 9ceae73

Please sign in to comment.