Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ET-VK][Op Redesign][7/n] Generalize ExecuteNode args with ArgGroup #2262

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
graph,
outputs_,
pipeline_barrier,
api::MemoryAccessType::WRITE,
descriptor_set,
idx);
idx = bind_values_to_descriptor_set(
graph,
inputs_,
pipeline_barrier,
api::MemoryAccessType::READ,
descriptor_set,
idx);
graph, args_, pipeline_barrier, descriptor_set, idx);
descriptor_set.bind(idx, params_.buffer());

context->register_shader_dispatch(
Expand Down
26 changes: 20 additions & 6 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ namespace vulkan {

class ComputeGraph;

/*
* Represents a group of shader arguments (images and/or buffers), with a common
* access permission.
*/
struct ArgGroup {
ArgGroup(const ValueRef ref, const api::MemoryAccessType access)
: refs{ref}, access(access) {}

ArgGroup(
const std::vector<ValueRef>& refs,
const api::MemoryAccessType access)
: refs(refs), access(access) {}

const std::vector<ValueRef> refs;
const api::MemoryAccessType access;
};

/*
* Represents a single execution op in a ML model. In graph mode, ops will be
* implemented in a derived class that implements encode, which will implement
Expand All @@ -36,14 +53,12 @@ class ExecuteNode final {
const api::ShaderInfo& shader,
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ValueRef>& outputs,
const std::vector<ValueRef>& inputs,
const std::vector<ArgGroup>& args,
api::UniformParamsBuffer&& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
outputs_(outputs),
inputs_(inputs),
args_(args),
params_(std::move(params)) {}

~ExecuteNode() = default;
Expand All @@ -54,8 +69,7 @@ class ExecuteNode final {
const api::ShaderInfo shader_;
const api::utils::uvec3 global_workgroup_size_;
const api::utils::uvec3 local_workgroup_size_;
const std::vector<ValueRef> outputs_;
const std::vector<ValueRef> inputs_;
const std::vector<ArgGroup> args_;
// TODO(T180906086): pass multiple buffers and index with ValueRef.
// TODO(T180906457): allow re-computing param buffers.
api::UniformParamsBuffer params_;
Expand Down
25 changes: 15 additions & 10 deletions backends/vulkan/runtime/graph/ops/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,26 @@ void bind_staging_to_descriptor_set(

uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ValueRef>& args,
const std::vector<ArgGroup>& args,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx) {
uint32_t idx = base_idx;
for (auto& arg : args) {
Value& val = graph->get_val(arg);
if (val.isTensor()) {
bind_tensor_to_descriptor_set(
val.toTensor(), pipeline_barrier, accessType, descriptor_set, idx++);
} else if (val.isStaging()) {
bind_staging_to_descriptor_set(val.toStaging(), descriptor_set, idx++);
} else {
VK_THROW("Unsupported type: ", val.type());
for (auto& ref : arg.refs) {
Value& val = graph->get_val(ref);
if (val.isTensor()) {
bind_tensor_to_descriptor_set(
val.toTensor(),
pipeline_barrier,
arg.access,
descriptor_set,
idx++);
} else if (val.isStaging()) {
bind_staging_to_descriptor_set(val.toStaging(), descriptor_set, idx++);
} else {
VK_THROW("Unsupported type: ", val.type());
}
}
}
return idx;
Expand Down
3 changes: 1 addition & 2 deletions backends/vulkan/runtime/graph/ops/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ void bind_staging_to_descriptor_set(

uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ValueRef>& args,
const std::vector<ArgGroup>& args,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx);

Expand Down
7 changes: 6 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ void add_arithmetic_node(
api::UniformParamsBuffer params(graph.context(), block);

graph.execute_nodes().emplace_back(new ExecuteNode(
shader, global_size, local_size, {out}, {arg1, arg2}, std::move(params)));
shader,
global_size,
local_size,
{{out, api::MemoryAccessType::WRITE},
{{arg1, arg2}, api::MemoryAccessType::READ}},
std::move(params)));
}

} // namespace vulkan
Expand Down
8 changes: 4 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ void add_staging_to_tensor_node(
shader,
global_size,
local_size,
{out_tensor},
{in_staging},
{{out_tensor, api::MemoryAccessType::WRITE},
{in_staging, api::MemoryAccessType::READ}},
std::move(params)));
}

Expand Down Expand Up @@ -94,8 +94,8 @@ void add_tensor_to_staging_node(
shader,
global_size,
local_size,
{in_tensor},
{out_staging},
{{in_tensor, api::MemoryAccessType::READ},
{out_staging, api::MemoryAccessType::WRITE}},
std::move(params)));
}

Expand Down
Loading