Skip to content

Commit

Permalink
[ET-VK][Op Redesign][7/n] Generalize ExecuteNode args with ArgGroup
Browse files Browse the repository at this point in the history
Leftover from Op Redesign 5/n - D54445787.

---

Typically, we specify outputs first and inputs second in the shader layout, but not always. In `image_to_nchw.glsl`, this is flipped:
https://www.internalfb.com/code/fbsource/[d303d229f22616bfba32e5bb5d4d27dc656f41a7]/fbcode/caffe2/aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl?lines=8-19

Hence, we generalize our `ExecuteNode` specification to take a vector of args (image, buffer, etc.), with specification of access type. Since typically we will group args of the same access together, we correspond one access specification to multiple args.

We reuse `api::MemoryAccessType` for access specification.

Differential Revision: [D54518840](https://our.internmc.facebook.com/intern/diff/D54518840/)

ghstack-source-id: 217489419
Pull Request resolved: #2262
  • Loading branch information
jorgep31415 committed Mar 5, 2024
1 parent 91c3d65 commit d445a34
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 36 deletions.
14 changes: 1 addition & 13 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
graph,
outputs_,
pipeline_barrier,
api::MemoryAccessType::WRITE,
descriptor_set,
idx);
idx = bind_values_to_descriptor_set(
graph,
inputs_,
pipeline_barrier,
api::MemoryAccessType::READ,
descriptor_set,
idx);
graph, args_, pipeline_barrier, descriptor_set, idx);
descriptor_set.bind(idx, params_.buffer());

context->register_shader_dispatch(
Expand Down
26 changes: 20 additions & 6 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ namespace vulkan {

class ComputeGraph;

/*
* Represents a group of shader arguments (images and/or buffers), with a common
* access permission.
*/
struct ArgGroup {
ArgGroup(const ValueRef ref, const api::MemoryAccessType access)
: refs{ref}, access(access) {}

ArgGroup(
const std::vector<ValueRef>& refs,
const api::MemoryAccessType access)
: refs(refs), access(access) {}

const std::vector<ValueRef> refs;
const api::MemoryAccessType access;
};

/*
* Represents a single execution op in a ML model. In graph mode, ops will be
* implemented in a derived class that implements encode, which will implement
Expand All @@ -36,14 +53,12 @@ class ExecuteNode final {
const api::ShaderInfo& shader,
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ValueRef>& outputs,
const std::vector<ValueRef>& inputs,
const std::vector<ArgGroup>& args,
api::UniformParamsBuffer&& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
outputs_(outputs),
inputs_(inputs),
args_(args),
params_(std::move(params)) {}

~ExecuteNode() = default;
Expand All @@ -54,8 +69,7 @@ class ExecuteNode final {
const api::ShaderInfo shader_;
const api::utils::uvec3 global_workgroup_size_;
const api::utils::uvec3 local_workgroup_size_;
const std::vector<ValueRef> outputs_;
const std::vector<ValueRef> inputs_;
const std::vector<ArgGroup> args_;
// TODO(T180906086): pass multiple buffers and index with ValueRef.
// TODO(T180906457): allow re-computing param buffers.
api::UniformParamsBuffer params_;
Expand Down
25 changes: 15 additions & 10 deletions backends/vulkan/runtime/graph/ops/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,26 @@ void bind_staging_to_descriptor_set(

uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ValueRef>& args,
const std::vector<ArgGroup>& args,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx) {
uint32_t idx = base_idx;
for (auto& arg : args) {
Value& val = graph->get_val(arg);
if (val.isTensor()) {
bind_tensor_to_descriptor_set(
val.toTensor(), pipeline_barrier, accessType, descriptor_set, idx++);
} else if (val.isStaging()) {
bind_staging_to_descriptor_set(val.toStaging(), descriptor_set, idx++);
} else {
VK_THROW("Unsupported type: ", val.type());
for (auto& ref : arg.refs) {
Value& val = graph->get_val(ref);
if (val.isTensor()) {
bind_tensor_to_descriptor_set(
val.toTensor(),
pipeline_barrier,
arg.access,
descriptor_set,
idx++);
} else if (val.isStaging()) {
bind_staging_to_descriptor_set(val.toStaging(), descriptor_set, idx++);
} else {
VK_THROW("Unsupported type: ", val.type());
}
}
}
return idx;
Expand Down
3 changes: 1 addition & 2 deletions backends/vulkan/runtime/graph/ops/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ void bind_staging_to_descriptor_set(

uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ValueRef>& args,
const std::vector<ArgGroup>& args,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx);

Expand Down
7 changes: 6 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ void add_arithmetic_node(
api::UniformParamsBuffer params(graph.context(), block);

graph.execute_nodes().emplace_back(new ExecuteNode(
shader, global_size, local_size, {out}, {arg1, arg2}, std::move(params)));
shader,
global_size,
local_size,
{{out, api::MemoryAccessType::WRITE},
{{arg1, arg2}, api::MemoryAccessType::READ}},
std::move(params)));
}

} // namespace vulkan
Expand Down
8 changes: 4 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ void add_staging_to_tensor_node(
shader,
global_size,
local_size,
{out_tensor},
{in_staging},
{{out_tensor, api::MemoryAccessType::WRITE},
{in_staging, api::MemoryAccessType::READ}},
std::move(params)));
}

Expand Down Expand Up @@ -94,8 +94,8 @@ void add_tensor_to_staging_node(
shader,
global_size,
local_size,
{in_tensor},
{out_staging},
{{in_tensor, api::MemoryAccessType::READ},
{out_staging, api::MemoryAccessType::WRITE}},
std::move(params)));
}

Expand Down

0 comments on commit d445a34

Please sign in to comment.