Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ET-VK][Op Redesign][4/n] Merge ArithmeticNode into ExecuteNode #2247

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

namespace at {
namespace native {
namespace vulkan {

// Records this node's shader dispatch with the graph's context.
//
// Descriptor bindings are assigned consecutively starting at slot 0: every
// value in outputs_ (bound for WRITE access), then every value in inputs_
// (bound for READ access), and finally the uniform parameter buffer in the
// slot that follows the last input.
void ExecuteNode::encode(ComputeGraph* graph) {
  api::Context* const ctx = graph->context();
  api::PipelineBarrier barrier{};

  // The lock object returned by the context lives until this function exits.
  std::unique_lock<std::mutex> dispatch_guard = ctx->dispatch_lock();

  api::DescriptorSet bindings =
      ctx->get_descriptor_set(shader_, local_workgroup_size_);

  const uint32_t first_slot = 0;

  // Each helper call returns the first slot it did not use, which becomes the
  // starting slot for the next group of bindings.
  const uint32_t first_input_slot = bind_values_to_descriptor_set(
      graph,
      outputs_,
      barrier,
      api::MemoryAccessType::WRITE,
      bindings,
      first_slot);
  const uint32_t params_slot = bind_values_to_descriptor_set(
      graph,
      inputs_,
      barrier,
      api::MemoryAccessType::READ,
      bindings,
      first_input_slot);

  bindings.bind(params_slot, params_.buffer());

  ctx->register_shader_dispatch(
      bindings, barrier, shader_, global_workgroup_size_);
}

} // namespace vulkan
} // namespace native
} // namespace at
27 changes: 22 additions & 5 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,37 @@ class ExecuteNode {

public:
ExecuteNode(ValueRef input, ValueRef output)
: inputs_{input}, outputs_{output} {}
: outputs_{output}, inputs_{input} {}

ExecuteNode(
const api::ShaderInfo& shader,
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ValueRef>& outputs,
const std::vector<ValueRef>& inputs,
const std::vector<ValueRef>& outputs)
: inputs_(inputs), outputs_(outputs) {}
api::UniformParamsBuffer&& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
outputs_(outputs),
inputs_(inputs),
params_(std::move(params)) {}

virtual ~ExecuteNode() = default;

protected:
std::vector<ValueRef> inputs_;
// TODO: Consider making members const after we remove StagingNode.
api::ShaderInfo shader_;
api::utils::uvec3 global_workgroup_size_;
api::utils::uvec3 local_workgroup_size_;
std::vector<ValueRef> outputs_;
std::vector<ValueRef> inputs_;
// TODO(T180906086): pass multiple buffers and index with ValueRef.
// TODO(T180906457): allow re-computing param buffers.
api::UniformParamsBuffer params_;

public:
virtual void encode(ComputeGraph* graph) const = 0;
virtual void encode(ComputeGraph* graph);
};

} // namespace vulkan
Expand Down
63 changes: 63 additions & 0 deletions backends/vulkan/runtime/graph/ops/Utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

namespace at {
namespace native {
namespace vulkan {

// Packs the four dimensions of `t` into an ivec4, ordered as
// (width, height, channel, batch).
api::utils::ivec4 get_size_as_ivec4(const vTensor& t) {
  const auto width = dim_at<Dim4D::Width>(t);
  const auto height = dim_at<Dim4D::Height>(t);
  const auto channel = dim_at<Dim4D::Channel>(t);
  const auto batch = dim_at<Dim4D::Batch>(t);

  return api::utils::make_ivec4({width, height, channel, batch});
}

// Binds the storage behind `tensor` to slot `idx` of `descriptor_set`.
//
// If `tensor.buffer()` evaluates to true the tensor's buffer is bound;
// otherwise its image is bound. In both cases the resource is requested for
// the COMPUTE pipeline stage with the given `accessType`, and
// `pipeline_barrier` is handed to the tensor accessor.
void bind_tensor_to_descriptor_set(
    vTensor& tensor,
    api::PipelineBarrier& pipeline_barrier,
    const api::MemoryAccessType accessType,
    api::DescriptorSet& descriptor_set,
    const uint32_t idx) {
  const auto stage = api::PipelineStage::COMPUTE;
  const bool backed_by_buffer = static_cast<bool>(tensor.buffer());

  if (!backed_by_buffer) {
    api::VulkanImage& img = tensor.image(pipeline_barrier, stage, accessType);
    descriptor_set.bind(idx, img);
    return;
  }

  api::VulkanBuffer& buf = tensor.buffer(pipeline_barrier, stage, accessType);
  descriptor_set.bind(idx, buf);
}

// Binds every value referenced by `args` to consecutive slots of
// `descriptor_set`, beginning at `base_idx`, and returns the first slot that
// was left unused.
//
// Only tensor values are accepted; any other value type is reported through
// VK_THROW.
uint32_t bind_values_to_descriptor_set(
    ComputeGraph* graph,
    const std::vector<ValueRef>& args,
    api::PipelineBarrier& pipeline_barrier,
    const api::MemoryAccessType accessType,
    api::DescriptorSet& descriptor_set,
    const uint32_t base_idx) {
  uint32_t next_slot = base_idx;

  for (const auto& ref : args) {
    Value& value = graph->get_val(ref);

    if (!value.isTensor()) {
      VK_THROW("Unsupported type: ", value.type());
    } else {
      bind_tensor_to_descriptor_set(
          value.toTensor(),
          pipeline_barrier,
          accessType,
          descriptor_set,
          next_slot);
      // A slot is consumed only once a tensor has been bound to it.
      ++next_slot;
    }
  }

  return next_slot;
}

} // namespace vulkan
} // namespace native
} // namespace at
19 changes: 18 additions & 1 deletion backends/vulkan/runtime/graph/ops/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/impl/Arithmetic.h>
#include <ATen/native/vulkan/impl/Common.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

Expand All @@ -21,6 +21,23 @@ namespace vulkan {
#define DECLARE_OP_FN(function) \
ValueRef function(ComputeGraph& graph, const std::vector<ValueRef>& args);

// Returns the dimensions of `t` packed as (width, height, channel, batch).
api::utils::ivec4 get_size_as_ivec4(const vTensor& t);

// Binds the storage of `tensor` to slot `idx` of `descriptor_set`: the
// tensor's buffer when it has one, otherwise its image. The resource is
// requested for the compute stage with the given `accessType`, and
// `pipeline_barrier` is passed through to the tensor accessor.
void bind_tensor_to_descriptor_set(
vTensor& tensor,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t idx);

// Binds each value referenced by `args` to consecutive slots of
// `descriptor_set`, starting at `base_idx`, and returns the next free slot.
// Only tensor values are supported; other value types raise via VK_THROW.
uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ValueRef>& args,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx);

} // namespace vulkan
} // namespace native
} // namespace at
Expand Down
67 changes: 30 additions & 37 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,39 @@

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>

#include <ATen/native/vulkan/impl/Common.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

namespace at {
namespace native {
namespace vulkan {

#define DEFINE_ARITHMETIC_FN(function, op_type) \
#define DEFINE_ARITHMETIC_FN(function, shader) \
ValueRef function(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
return add_arithmetic_node( \
graph, \
args[0], \
args[1], \
args[2], \
arithmetic::OpType::op_type, \
args[3]); \
graph, args[0], args[1], args[2], VK_KERNEL(shader), args[3]); \
}

DEFINE_ARITHMETIC_FN(add, ADD);
DEFINE_ARITHMETIC_FN(sub, SUB);
DEFINE_ARITHMETIC_FN(mul, MUL);
DEFINE_ARITHMETIC_FN(div, DIV);
DEFINE_ARITHMETIC_FN(floor_div, FLOOR_DIV);
DEFINE_ARITHMETIC_FN(pow, POW);
DEFINE_ARITHMETIC_FN(add, add);
DEFINE_ARITHMETIC_FN(sub, sub);
DEFINE_ARITHMETIC_FN(mul, mul);
DEFINE_ARITHMETIC_FN(div, div);
DEFINE_ARITHMETIC_FN(floor_div, floor_divide);
DEFINE_ARITHMETIC_FN(pow, pow);

// TODO(T180908843): Bypass this entrypoint function by creating `ValueRef out`
// ahead of time.
ValueRef add_arithmetic_node(
ComputeGraph& graph,
const ValueRef in1,
const ValueRef in2,
const float alpha,
const arithmetic::OpType optype,
const api::ShaderInfo& shader,
const int64_t shared_object_idx) {
std::vector<int64_t> in1_sizes = graph.get_val_sizes(in1);
api::ScalarType in1_dtype = graph.get_val_dtype(in1);

ValueRef out = graph.add_tensor(in1_sizes, in1_dtype, shared_object_idx);
add_arithmetic_node(graph, in1, in2, out, alpha, optype);
add_arithmetic_node(graph, in1, in2, out, alpha, shader);
return out;
}

Expand All @@ -67,12 +62,27 @@ void add_arithmetic_node(
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype) {
const api::ShaderInfo& shader) {
ValueRef arg1 = prepack_if_tensor_ref(graph, in1);
ValueRef arg2 = prepack_if_tensor_ref(graph, in2);

graph.execute_nodes().emplace_back(
new ArithmeticNode(arg1, arg2, out, alpha, optype));
vTensor& t_in1 = graph.get_val(arg1).toTensor();
vTensor& t_in2 = graph.get_val(arg2).toTensor();
vTensor& t_out = graph.get_val(out).toTensor();

api::utils::uvec3 global_size = t_out.extents();
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

ArithmeticParams block{
get_size_as_ivec4(t_out),
get_size_as_ivec4(t_in1),
get_size_as_ivec4(t_in2),
1.0,
};
api::UniformParamsBuffer params(graph.context(), block);

graph.execute_nodes().emplace_back(new ExecuteNode(
shader, global_size, local_size, {out}, {arg1, arg2}, std::move(params)));
}

ArithmeticPrepack::ArithmeticPrepack(const ValueRef tref, const ValueRef packed)
Expand All @@ -92,23 +102,6 @@ void ArithmeticPrepack::encode(ComputeGraph* graph) const {
encode_copy_to_vtensor(graph->context(), staging, packed);
}

ArithmeticNode::ArithmeticNode(
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype)
: ExecuteNode({in1, in2}, {out}), alpha_(alpha), optype_(optype) {}

void ArithmeticNode::encode(ComputeGraph* graph) const {
vTensor& in1 = graph->get_val(inputs_[0]).toTensor();
vTensor& in2 = graph->get_val(inputs_[1]).toTensor();
vTensor& out = graph->get_val(outputs_[0]).toTensor();

api::ShaderInfo kernel = arithmetic::get_shader(optype_);
arithmetic::record_op(graph->context(), kernel, in1, in2, out, alpha_);
}

} // namespace vulkan
} // namespace native
} // namespace at
27 changes: 9 additions & 18 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ ValueRef add_arithmetic_node(
const ValueRef in1,
const ValueRef in2,
const float alpha,
const arithmetic::OpType optype,
const api::ShaderInfo& shader,
const int64_t shared_object_idx = -1);

void add_arithmetic_node(
Expand All @@ -41,29 +41,20 @@ void add_arithmetic_node(
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype);
const api::ShaderInfo& shader);

class ArithmeticPrepack : public virtual PrepackNode {
public:
explicit ArithmeticPrepack(const ValueRef tref, const ValueRef packed);

void encode(ComputeGraph* graph) const override;
// Parameter block for the arithmetic shaders. add_arithmetic_node fills one
// of these and wraps it in an api::UniformParamsBuffer for the ExecuteNode.
// NOTE(review): the field order presumably has to match the uniform block
// declared in the shaders — confirm before reordering.
struct ArithmeticParams final {
// Sizes of the output tensor, as produced by get_size_as_ivec4.
api::utils::ivec4 outputSizes;
// Sizes of the first input tensor.
api::utils::ivec4 input1Sizes;
// Sizes of the second input tensor.
api::utils::ivec4 input2Sizes;
// Scalar passed to the shader alongside the sizes.
float alpha;
};

class ArithmeticNode : public virtual ExecuteNode {
class ArithmeticPrepack : public virtual PrepackNode {
public:
explicit ArithmeticNode(
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype);
explicit ArithmeticPrepack(const ValueRef tref, const ValueRef packed);

void encode(ComputeGraph* graph) const override;

private:
float alpha_;
arithmetic::OpType optype_;
};

} // namespace vulkan
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ void encode_copy_from_vtensor(

StagingNode::StagingNode(ValueRef from, ValueRef to) : ExecuteNode(from, to) {}

void StagingNode::encode(ComputeGraph* graph) const {
void StagingNode::encode(ComputeGraph* graph) {
Value& in_val = graph->get_val(inputs_[0]);
Value& out_val = graph->get_val(outputs_[0]);

Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Staging.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class StagingNode : public virtual ExecuteNode {
public:
explicit StagingNode(ValueRef from, ValueRef to);

void encode(ComputeGraph* graph) const override;
void encode(ComputeGraph* graph) override;
};

} // namespace vulkan
Expand Down
Loading
Loading