Skip to content

Commit

Permalink
Merge ArithmeticNode into ExecuteNode (#2247)
Browse files Browse the repository at this point in the history
Summary:
bypass-github-export-checks

Pull Request resolved: #2247

This diff moves the logic of `ArithmeticNode` into its corresponding OpFunction `add_arithmetic_node()` and the `ExecuteNode` class.

Our aim is to remove all derived classes of `ExecuteNode`, i.e., to make `ExecuteNode` a final class. All operator-specific logic will be handled in the OpFunction.

Note the next change will move `StagingNode` into its OpFunction + this new ExecuteNode implementation. Until then, we can't tidy up the `ExecuteNode` class fully. Finally, we leave a few task TODOs.
ghstack-source-id: 217439330
exported-using-ghexport

Reviewed By: SS-JIA

Differential Revision: D53982441

fbshipit-source-id: b8a51eee538b679e4168864a4870f3921c9ba333
  • Loading branch information
jorgep31415 authored and facebook-github-bot committed Mar 5, 2024
1 parent fae9ef0 commit 862f755
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 77 deletions.
51 changes: 51 additions & 0 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

namespace at {
namespace native {
namespace vulkan {

void ExecuteNode::encode(ComputeGraph* graph) {
api::Context* const context = graph->context();
api::PipelineBarrier pipeline_barrier{};

std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

api::DescriptorSet descriptor_set =
context->get_descriptor_set(shader_, local_workgroup_size_);

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
graph,
outputs_,
pipeline_barrier,
api::MemoryAccessType::WRITE,
descriptor_set,
idx);
idx = bind_values_to_descriptor_set(
graph,
inputs_,
pipeline_barrier,
api::MemoryAccessType::READ,
descriptor_set,
idx);
descriptor_set.bind(idx, params_.buffer());

context->register_shader_dispatch(
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
}

} // namespace vulkan
} // namespace native
} // namespace at
27 changes: 22 additions & 5 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,37 @@ class ExecuteNode {

public:
ExecuteNode(ValueRef input, ValueRef output)
: inputs_{input}, outputs_{output} {}
: outputs_{output}, inputs_{input} {}

ExecuteNode(
const api::ShaderInfo& shader,
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ValueRef>& outputs,
const std::vector<ValueRef>& inputs,
const std::vector<ValueRef>& outputs)
: inputs_(inputs), outputs_(outputs) {}
api::UniformParamsBuffer&& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
outputs_(outputs),
inputs_(inputs),
params_(std::move(params)) {}

virtual ~ExecuteNode() = default;

protected:
std::vector<ValueRef> inputs_;
// TODO: Consider making members const after we remove StagingNode.
api::ShaderInfo shader_;
api::utils::uvec3 global_workgroup_size_;
api::utils::uvec3 local_workgroup_size_;
std::vector<ValueRef> outputs_;
std::vector<ValueRef> inputs_;
// TODO(T180906086): pass multiple buffers and index with ValueRef.
// TODO(T180906457): allow re-computing param buffers.
api::UniformParamsBuffer params_;

public:
virtual void encode(ComputeGraph* graph) const = 0;
virtual void encode(ComputeGraph* graph);
};

} // namespace vulkan
Expand Down
63 changes: 63 additions & 0 deletions backends/vulkan/runtime/graph/ops/Utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

namespace at {
namespace native {
namespace vulkan {

api::utils::ivec4 get_size_as_ivec4(const vTensor& t) {
return api::utils::make_ivec4(
{dim_at<Dim4D::Width>(t),
dim_at<Dim4D::Height>(t),
dim_at<Dim4D::Channel>(t),
dim_at<Dim4D::Batch>(t)});
}

void bind_tensor_to_descriptor_set(
vTensor& tensor,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t idx) {
if (tensor.buffer()) {
api::VulkanBuffer& buffer = tensor.buffer(
pipeline_barrier, api::PipelineStage::COMPUTE, accessType);
descriptor_set.bind(idx, buffer);
} else {
api::VulkanImage& image =
tensor.image(pipeline_barrier, api::PipelineStage::COMPUTE, accessType);
descriptor_set.bind(idx, image);
}
}

uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ValueRef>& args,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx) {
uint32_t idx = base_idx;
for (auto& arg : args) {
Value& val = graph->get_val(arg);
if (val.isTensor()) {
vTensor& tensor = val.toTensor();
bind_tensor_to_descriptor_set(
tensor, pipeline_barrier, accessType, descriptor_set, idx++);
} else {
VK_THROW("Unsupported type: ", val.type());
}
}
return idx;
}

} // namespace vulkan
} // namespace native
} // namespace at
19 changes: 18 additions & 1 deletion backends/vulkan/runtime/graph/ops/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/impl/Arithmetic.h>
#include <ATen/native/vulkan/impl/Common.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

Expand All @@ -21,6 +21,23 @@ namespace vulkan {
#define DECLARE_OP_FN(function) \
ValueRef function(ComputeGraph& graph, const std::vector<ValueRef>& args);

api::utils::ivec4 get_size_as_ivec4(const vTensor& t);

void bind_tensor_to_descriptor_set(
vTensor& tensor,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t idx);

uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ValueRef>& args,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx);

} // namespace vulkan
} // namespace native
} // namespace at
Expand Down
67 changes: 30 additions & 37 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,39 @@

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>

#include <ATen/native/vulkan/impl/Common.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

namespace at {
namespace native {
namespace vulkan {

#define DEFINE_ARITHMETIC_FN(function, op_type) \
#define DEFINE_ARITHMETIC_FN(function, shader) \
ValueRef function(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
return add_arithmetic_node( \
graph, \
args[0], \
args[1], \
args[2], \
arithmetic::OpType::op_type, \
args[3]); \
graph, args[0], args[1], args[2], VK_KERNEL(shader), args[3]); \
}

DEFINE_ARITHMETIC_FN(add, ADD);
DEFINE_ARITHMETIC_FN(sub, SUB);
DEFINE_ARITHMETIC_FN(mul, MUL);
DEFINE_ARITHMETIC_FN(div, DIV);
DEFINE_ARITHMETIC_FN(floor_div, FLOOR_DIV);
DEFINE_ARITHMETIC_FN(pow, POW);
DEFINE_ARITHMETIC_FN(add, add);
DEFINE_ARITHMETIC_FN(sub, sub);
DEFINE_ARITHMETIC_FN(mul, mul);
DEFINE_ARITHMETIC_FN(div, div);
DEFINE_ARITHMETIC_FN(floor_div, floor_divide);
DEFINE_ARITHMETIC_FN(pow, pow);

// TODO(T180908843): Bypass this entrypoint function by creating `ValueRef out`
// ahead of time.
ValueRef add_arithmetic_node(
ComputeGraph& graph,
const ValueRef in1,
const ValueRef in2,
const float alpha,
const arithmetic::OpType optype,
const api::ShaderInfo& shader,
const int64_t shared_object_idx) {
std::vector<int64_t> in1_sizes = graph.get_val_sizes(in1);
api::ScalarType in1_dtype = graph.get_val_dtype(in1);

ValueRef out = graph.add_tensor(in1_sizes, in1_dtype, shared_object_idx);
add_arithmetic_node(graph, in1, in2, out, alpha, optype);
add_arithmetic_node(graph, in1, in2, out, alpha, shader);
return out;
}

Expand All @@ -67,12 +62,27 @@ void add_arithmetic_node(
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype) {
const api::ShaderInfo& shader) {
ValueRef arg1 = prepack_if_tensor_ref(graph, in1);
ValueRef arg2 = prepack_if_tensor_ref(graph, in2);

graph.execute_nodes().emplace_back(
new ArithmeticNode(arg1, arg2, out, alpha, optype));
vTensor& t_in1 = graph.get_val(arg1).toTensor();
vTensor& t_in2 = graph.get_val(arg2).toTensor();
vTensor& t_out = graph.get_val(out).toTensor();

api::utils::uvec3 global_size = t_out.extents();
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

ArithmeticParams block{
get_size_as_ivec4(t_out),
get_size_as_ivec4(t_in1),
get_size_as_ivec4(t_in2),
1.0,
};
api::UniformParamsBuffer params(graph.context(), block);

graph.execute_nodes().emplace_back(new ExecuteNode(
shader, global_size, local_size, {out}, {arg1, arg2}, std::move(params)));
}

ArithmeticPrepack::ArithmeticPrepack(const ValueRef tref, const ValueRef packed)
Expand All @@ -92,23 +102,6 @@ void ArithmeticPrepack::encode(ComputeGraph* graph) const {
encode_copy_to_vtensor(graph->context(), staging, packed);
}

ArithmeticNode::ArithmeticNode(
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype)
: ExecuteNode({in1, in2}, {out}), alpha_(alpha), optype_(optype) {}

void ArithmeticNode::encode(ComputeGraph* graph) const {
vTensor& in1 = graph->get_val(inputs_[0]).toTensor();
vTensor& in2 = graph->get_val(inputs_[1]).toTensor();
vTensor& out = graph->get_val(outputs_[0]).toTensor();

api::ShaderInfo kernel = arithmetic::get_shader(optype_);
arithmetic::record_op(graph->context(), kernel, in1, in2, out, alpha_);
}

} // namespace vulkan
} // namespace native
} // namespace at
27 changes: 9 additions & 18 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ ValueRef add_arithmetic_node(
const ValueRef in1,
const ValueRef in2,
const float alpha,
const arithmetic::OpType optype,
const api::ShaderInfo& shader,
const int64_t shared_object_idx = -1);

void add_arithmetic_node(
Expand All @@ -41,29 +41,20 @@ void add_arithmetic_node(
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype);
const api::ShaderInfo& shader);

class ArithmeticPrepack : public virtual PrepackNode {
public:
explicit ArithmeticPrepack(const ValueRef tref, const ValueRef packed);

void encode(ComputeGraph* graph) const override;
struct ArithmeticParams final {
api::utils::ivec4 outputSizes;
api::utils::ivec4 input1Sizes;
api::utils::ivec4 input2Sizes;
float alpha;
};

class ArithmeticNode : public virtual ExecuteNode {
class ArithmeticPrepack : public virtual PrepackNode {
public:
explicit ArithmeticNode(
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype);
explicit ArithmeticPrepack(const ValueRef tref, const ValueRef packed);

void encode(ComputeGraph* graph) const override;

private:
float alpha_;
arithmetic::OpType optype_;
};

} // namespace vulkan
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ void encode_copy_from_vtensor(

StagingNode::StagingNode(ValueRef from, ValueRef to) : ExecuteNode(from, to) {}

void StagingNode::encode(ComputeGraph* graph) const {
void StagingNode::encode(ComputeGraph* graph) {
Value& in_val = graph->get_val(inputs_[0]);
Value& out_val = graph->get_val(outputs_[0]);

Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Staging.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class StagingNode : public virtual ExecuteNode {
public:
explicit StagingNode(ValueRef from, ValueRef to);

void encode(ComputeGraph* graph) const override;
void encode(ComputeGraph* graph) override;
};

} // namespace vulkan
Expand Down
Loading

0 comments on commit 862f755

Please sign in to comment.