Skip to content

Commit

Permalink
Nit Arithmetic cleanup (#2246)
Browse files Browse the repository at this point in the history
Summary:
bypass-github-export-checks

Pull Request resolved: #2246

Facilitate code review before the big refactoring. Create a `maybe_prepack()` helper and improve variable naming.
ghstack-source-id: 217394066
exported-using-ghexport

Reviewed By: SS-JIA

Differential Revision: D54400674

fbshipit-source-id: 1912eabeaa9882d56b3b2a7a30b1ceb28a377559
  • Loading branch information
jorgep31415 authored and facebook-github-bot committed Mar 5, 2024
1 parent dfb5f51 commit fae9ef0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 35 deletions.
54 changes: 25 additions & 29 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,44 +36,40 @@ DEFINE_ARITHMETIC_FN(pow, POW);

ValueRef add_arithmetic_node(
ComputeGraph& graph,
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const float alpha,
const arithmetic::OpType optype,
const int64_t shared_object_idx) {
std::vector<int64_t> t1_sizes = graph.get_val_sizes(t1);
api::ScalarType t1_dtype = graph.get_val_dtype(t1);
std::vector<int64_t> in1_sizes = graph.get_val_sizes(in1);
api::ScalarType in1_dtype = graph.get_val_dtype(in1);

ValueRef out = graph.add_tensor(t1_sizes, t1_dtype, shared_object_idx);
add_arithmetic_node(graph, t1, t2, out, alpha, optype);
ValueRef out = graph.add_tensor(in1_sizes, in1_dtype, shared_object_idx);
add_arithmetic_node(graph, in1, in2, out, alpha, optype);
return out;
}

// TODO(T181006464): Move to Utils when we remove ArithmeticPrepack.
ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v) {
if (graph.get_val(v).isTensor()) {
return v;
} else {
TensorRef& tRef = graph.get_val(v).toTensorRef();
ValueRef vTen = graph.add_tensor(tRef.sizes, tRef.dtype);
graph.prepack_nodes().emplace_back(new ArithmeticPrepack(v, vTen));
return vTen;
}
}

void add_arithmetic_node(
ComputeGraph& graph,
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype) {
// Prepacking first arg (if needed)
ValueRef arg1 = t1;
if (graph.get_val(t1).isTensorRef()) {
TensorRef& t1_asref = graph.get_val(t1).toTensorRef();
ValueRef t1_vten = graph.add_tensor(t1_asref.sizes, t1_asref.dtype);
graph.prepack_nodes().emplace_back(new ArithmeticPrepack(t1, t1_vten));
arg1 = t1_vten;
}
VK_CHECK_COND(graph.get_val(arg1).isTensor());
// Prepacking second arg (if needed)
ValueRef arg2 = t2;
if (graph.get_val(t2).isTensorRef()) {
TensorRef& t2_asref = graph.get_val(t2).toTensorRef();
ValueRef t2_vten = graph.add_tensor(t2_asref.sizes, t2_asref.dtype);
graph.prepack_nodes().emplace_back(new ArithmeticPrepack(t2, t2_vten));
arg2 = t2_vten;
}
VK_CHECK_COND(graph.get_val(arg2).isTensor());
ValueRef arg1 = prepack_if_tensor_ref(graph, in1);
ValueRef arg2 = prepack_if_tensor_ref(graph, in2);

graph.execute_nodes().emplace_back(
new ArithmeticNode(arg1, arg2, out, alpha, optype));
Expand All @@ -97,12 +93,12 @@ void ArithmeticPrepack::encode(ComputeGraph* graph) const {
}

ArithmeticNode::ArithmeticNode(
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype)
: ExecuteNode({t1, t2}, {out}), alpha_(alpha), optype_(optype) {}
: ExecuteNode({in1, in2}, {out}), alpha_(alpha), optype_(optype) {}

void ArithmeticNode::encode(ComputeGraph* graph) const {
vTensor& in1 = graph->get_val(inputs_[0]).toTensor();
Expand Down
12 changes: 6 additions & 6 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ DECLARE_OP_FN(pow);

ValueRef add_arithmetic_node(
ComputeGraph& graph,
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const float alpha,
const arithmetic::OpType optype,
const int64_t shared_object_idx = -1);

void add_arithmetic_node(
ComputeGraph& graph,
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype);
Expand All @@ -53,8 +53,8 @@ class ArithmeticPrepack : public virtual PrepackNode {
class ArithmeticNode : public virtual ExecuteNode {
public:
explicit ArithmeticNode(
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype);
Expand Down

0 comments on commit fae9ef0

Please sign in to comment.