Skip to content

Commit

Permalink
Generate dot() in the Metal backend (#7085)
Browse files Browse the repository at this point in the history
* dot() support for Metal backend)

* Restrict dot() to floats
  • Loading branch information
vksnk authored and steven-johnson committed Oct 24, 2022
1 parent e7dfaac commit eca189b
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/CodeGen_Metal_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class CodeGen_Metal_Dev : public CodeGen_GPU_Dev {
void visit(const Allocate *op) override;
void visit(const Free *op) override;
void visit(const Cast *op) override;
void visit(const VectorReduce *op) override;
void visit(const Atomic *op) override;
};

Expand Down Expand Up @@ -223,6 +224,20 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Min *op) {
print_expr(Call::make(op->type, "min", {op->a, op->b}, Call::Extern));
}

void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const VectorReduce *op) {
if (op->op == VectorReduce::Add && op->type.is_float() && (op->type.lanes() == 1)) {
if (const Mul *maybe_mul = op->value.as<Mul>()) {
string a = print_expr(maybe_mul->a);
string b = print_expr(maybe_mul->b);
ostringstream rhs;
rhs << "dot(" << a << ", " << b << ")";
print_assignment(op->type, rhs.str());
return;
}
}
CodeGen_GPU_C::visit(op);
}

void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Div *op) {
int bits;
if (is_const_power_of_two_integer(op->b, &bits)) {
Expand Down

0 comments on commit eca189b

Please sign in to comment.