Use lazy descriptor pool allocation #2285

Closed · wants to merge 1 commit
37 changes: 3 additions & 34 deletions backends/vulkan/runtime/VulkanBackend.cpp
@@ -62,39 +62,6 @@ api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
}
}

GraphConfig generate_config() {
const uint32_t submit_frequency = UINT32_MAX;

const api::CommandPoolConfig cmd_config{
4u, // cmdPoolInitialSize
2u, // cmdPoolBatchSize
};

const api::DescriptorPoolConfig descriptor_pool_config{
1024u, // descriptorPoolMaxSets
1024u, // descriptorUniformBufferCount
1024u, // descriptorStorageBufferCount
1024u, // descriptorCombinedSamplerCount
1024u, // descriptorStorageImageCount
32u, // descriptorPileSizes
};

const api::QueryPoolConfig query_pool_config{};

const api::ContextConfig context_config{
submit_frequency, // cmdSubmitFrequency
cmd_config, // cmdPoolConfig
descriptor_pool_config, // descriptorPoolConfig
query_pool_config, // queryPoolConfig
};

const GraphConfig graph_config{
context_config,
};

return graph_config;
}

class GraphBuilder {
ComputeGraph* compute_graph_;
VkGraphPtr flatbuffer_;
@@ -269,6 +236,8 @@ class VulkanBackend final : public PyTorchBackendInterface {

builder.build_graph();

compute_graph->prepare();

compute_graph->encode_prepack();
compute_graph->prepack();

@@ -284,7 +253,7 @@
ComputeGraph* compute_graph = ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(
context.get_runtime_allocator(), ComputeGraph);

new (compute_graph) ComputeGraph(generate_config());
new (compute_graph) ComputeGraph(GraphConfig());

Error err = compileModel(processed->data(), compute_graph);

65 changes: 65 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -18,6 +18,8 @@ namespace vulkan {

ComputeGraph::ComputeGraph(GraphConfig config)
: config_{config},
prepack_descriptor_counts_{},
execute_descriptor_counts_{},
context_{new api::Context(
api::runtime()->default_adapter_i(),
config_.contextConfig)},
@@ -27,6 +29,19 @@ ComputeGraph::ComputeGraph(GraphConfig config)
execute_nodes_{},
inputs_{},
outputs_{} {
// Ensure that descriptor counts are initialized to 0
prepack_descriptor_counts_.descriptorPoolMaxSets = 0;
prepack_descriptor_counts_.descriptorUniformBufferCount = 0;
prepack_descriptor_counts_.descriptorStorageBufferCount = 0;
prepack_descriptor_counts_.descriptorCombinedSamplerCount = 0;
prepack_descriptor_counts_.descriptorStorageImageCount = 0;

execute_descriptor_counts_.descriptorPoolMaxSets = 0;
execute_descriptor_counts_.descriptorUniformBufferCount = 0;
execute_descriptor_counts_.descriptorStorageBufferCount = 0;
execute_descriptor_counts_.descriptorCombinedSamplerCount = 0;
execute_descriptor_counts_.descriptorStorageImageCount = 0;

context_->set_cmd(/*reusable = */ true);
}

@@ -39,6 +54,33 @@ ComputeGraph::~ComputeGraph() {
context_->flush();
}

void ComputeGraph::update_descriptor_counts(
const api::ShaderInfo& shader_info,
bool execute) {
api::DescriptorPoolConfig* config =
execute ? &execute_descriptor_counts_ : &prepack_descriptor_counts_;

config->descriptorPoolMaxSets += 1;
for (const VkDescriptorType arg_type : shader_info.kernel_layout) {
switch (arg_type) {
case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
config->descriptorUniformBufferCount += 1;
break;
case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
config->descriptorStorageBufferCount += 1;
break;
case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
config->descriptorCombinedSamplerCount += 1;
break;
case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
config->descriptorStorageImageCount += 1;
break;
default:
VK_THROW("Unsupported descriptor type!");
}
}
}

ValueRef ComputeGraph::add_tensor(
const std::vector<int64_t>& sizes,
const api::ScalarType dtype,
@@ -138,6 +180,29 @@ void ComputeGraph::copy_from_staging(
copy_staging_to_ptr(staging, data, nbytes);
}

void ComputeGraph::prepare() {
#define MERGE_FIELD(field) \
static_cast<uint32_t>(std::ceil( \
std::max( \
execute_descriptor_counts_.field, \
prepack_descriptor_counts_.field) * \
config_.descriptorPoolSafetyFactor))

api::DescriptorPoolConfig config{
MERGE_FIELD(descriptorPoolMaxSets),
MERGE_FIELD(descriptorUniformBufferCount),
MERGE_FIELD(descriptorStorageBufferCount),
MERGE_FIELD(descriptorCombinedSamplerCount),
MERGE_FIELD(descriptorStorageImageCount),
1u,
};

if (!context_->descriptor_pool()) {
context_->descriptor_pool().init(config);
}
#undef MERGE_FIELD
}

void ComputeGraph::encode_prepack() {
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
node->encode(this);
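
The MERGE_FIELD macro in prepare() above is just this arithmetic, applied to each descriptor-pool field in turn; a stand-alone illustration with hypothetical tallies (40 prepack, 100 execute) and the default 1.25 safety factor:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const uint32_t prepack_count = 40;   // hypothetical prepack tally for one field
  const uint32_t execute_count = 100;  // hypothetical execute tally for the same field
  const float safety_factor = 1.25f;   // GraphConfig::descriptorPoolSafetyFactor default

  // Take the larger of the two tallies, scale by the safety factor, round up.
  const uint32_t pool_size = static_cast<uint32_t>(
      std::ceil(std::max(prepack_count, execute_count) * safety_factor));

  std::cout << pool_size << "\n";  // prints 125
}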
13 changes: 13 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
@@ -60,6 +60,9 @@ class ComputeGraph final {

private:
GraphConfig config_;
api::DescriptorPoolConfig prepack_descriptor_counts_;
api::DescriptorPoolConfig execute_descriptor_counts_;

std::unique_ptr<api::Context> context_;
std::vector<SharedObject> shared_objects_;
std::vector<Value> values_;
@@ -87,6 +90,10 @@ class ComputeGraph final {
return outputs_;
}

void update_descriptor_counts(
const api::ShaderInfo& shader_info,
bool execute);

/*
* Returns the value at a particular reference
*/
@@ -163,6 +170,12 @@

SharedObject& get_shared_object(const int64_t idx);

//
// Graph Preparation
//

void prepare();

//
// Input/Output
//
56 changes: 56 additions & 0 deletions backends/vulkan/runtime/graph/GraphConfig.cpp
@@ -0,0 +1,56 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/GraphConfig.h>

namespace at {
namespace native {
namespace vulkan {

GraphConfig::GraphConfig() {
// No automatic submissions
const uint32_t submit_frequency = UINT32_MAX;

// Only one command buffer will be encoded at a time
const api::CommandPoolConfig cmd_config{
1u, // cmdPoolInitialSize
1u, // cmdPoolBatchSize
};

// Use lazy descriptor pool initialization by default; the graph runtime will
// tally up the number of descriptor sets needed while building the graph and
// trigger descriptor pool initialization with exact sizes before encoding the
// command buffer.
const api::DescriptorPoolConfig descriptor_pool_config{
0u, // descriptorPoolMaxSets
0u, // descriptorUniformBufferCount
0u, // descriptorStorageBufferCount
0u, // descriptorCombinedSamplerCount
0u, // descriptorStorageImageCount
0u, // descriptorPileSizes
};

const api::QueryPoolConfig query_pool_config{};

const api::ContextConfig context_config{
submit_frequency, // cmdSubmitFrequency
cmd_config, // cmdPoolConfig
descriptor_pool_config, // descriptorPoolConfig
query_pool_config, // queryPoolConfig
};

contextConfig = context_config;

// Empirically selected safety factor. If descriptor pools start running out
// of memory, increase this safety factor.
descriptorPoolSafetyFactor = 1.25;
}

} // namespace vulkan
} // namespace native
} // namespace at
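
To make the "tally up ... while building the graph" comment concrete, here is a self-contained sketch of the counting idea behind ComputeGraph::update_descriptor_counts(), with the Vulkan descriptor types mocked as a plain enum and an assumed shader layout (one storage image, two combined samplers, one uniform buffer):

#include <cstdint>
#include <iostream>
#include <vector>

// Mocked stand-in for VkDescriptorType; the real code switches on that enum.
enum class DescriptorType { UniformBuffer, StorageBuffer, CombinedSampler, StorageImage };

struct PoolCounts {
  uint32_t max_sets = 0;
  uint32_t uniform_buffers = 0;
  uint32_t storage_buffers = 0;
  uint32_t combined_samplers = 0;
  uint32_t storage_images = 0;
};

// One descriptor set per dispatch, plus one descriptor per layout entry.
void tally(PoolCounts& counts, const std::vector<DescriptorType>& layout) {
  counts.max_sets += 1;
  for (DescriptorType t : layout) {
    switch (t) {
      case DescriptorType::UniformBuffer:   counts.uniform_buffers += 1;   break;
      case DescriptorType::StorageBuffer:   counts.storage_buffers += 1;   break;
      case DescriptorType::CombinedSampler: counts.combined_samplers += 1; break;
      case DescriptorType::StorageImage:    counts.storage_images += 1;    break;
    }
  }
}

int main() {
  PoolCounts execute_counts;
  // Assumed layout for a hypothetical binary-op shader: output image, two
  // sampled inputs, one params uniform buffer.
  tally(execute_counts,
        {DescriptorType::StorageImage, DescriptorType::CombinedSampler,
         DescriptorType::CombinedSampler, DescriptorType::UniformBuffer});
  std::cout << execute_counts.max_sets << " set(s), "
            << execute_counts.combined_samplers << " combined sampler(s)\n";
}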
10 changes: 10 additions & 0 deletions backends/vulkan/runtime/graph/GraphConfig.h
@@ -18,6 +18,16 @@ namespace vulkan {

struct GraphConfig final {
api::ContextConfig contextConfig;

// Creating a descriptor pool with exactly the number of descriptors tallied
// by iterating through the shader layouts of shaders used in the graph risks
// the descriptor pool running out of memory, therefore apply a safety factor
// to descriptor counts when creating the descriptor pool to mitigate this
// risk.
float descriptorPoolSafetyFactor;

// Generate a default graph config with pre-configured settings
explicit GraphConfig();
};

} // namespace vulkan
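
Following the comment above, a hypothetical usage sketch (not shown in this PR) of raising the safety factor before constructing a graph, should a model still exhaust its descriptor pool:

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

using namespace at::native::vulkan;

void build_graph_with_larger_pool_margin() {
  GraphConfig config;                        // defaults: lazy (zero) pool sizes, factor 1.25
  config.descriptorPoolSafetyFactor = 1.5f;  // assumed example value, not from the PR
  ComputeGraph graph(config);                // the pool itself is created later, in prepare()
}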
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
@@ -16,6 +16,21 @@ namespace at {
namespace native {
namespace vulkan {

ExecuteNode::ExecuteNode(
ComputeGraph& graph,
const api::ShaderInfo& shader,
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
api::UniformParamsBuffer&& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
args_(args),
params_(std::move(params)) {
graph.update_descriptor_counts(shader, /*execute = */ true);
}

void ExecuteNode::encode(ComputeGraph* graph) {
api::Context* const context = graph->context();
api::PipelineBarrier pipeline_barrier{};
8 changes: 2 additions & 6 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.h
@@ -50,16 +50,12 @@ class ExecuteNode final {

public:
ExecuteNode(
ComputeGraph& graph,
const api::ShaderInfo& shader,
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
api::UniformParamsBuffer&& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
args_(args),
params_(std::move(params)) {}
api::UniformParamsBuffer&& params);

~ExecuteNode() = default;

17 changes: 17 additions & 0 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
@@ -17,6 +17,23 @@ namespace at {
namespace native {
namespace vulkan {

PrepackNode::PrepackNode(
ComputeGraph& graph,
const api::ShaderInfo& shader,
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const ValueRef tref,
const ValueRef packed,
api::UniformParamsBuffer&& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
tref_(tref),
packed_(packed),
params_(std::move(params)) {
graph.update_descriptor_counts(shader, /*execute = */ false);
}

void PrepackNode::encode(ComputeGraph* graph) {
api::Context* const context = graph->context();
api::PipelineBarrier pipeline_barrier{};
9 changes: 2 additions & 7 deletions backends/vulkan/runtime/graph/ops/PrepackNode.h
@@ -33,18 +33,13 @@ class PrepackNode final {

public:
PrepackNode(
ComputeGraph& graph,
const api::ShaderInfo& shader,
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const ValueRef tref,
const ValueRef packed,
api::UniformParamsBuffer&& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
tref_(tref),
packed_(packed),
params_(std::move(params)) {}
api::UniformParamsBuffer&& params);

~PrepackNode() = default;

1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
@@ -72,6 +72,7 @@ void add_arithmetic_node(
api::UniformParamsBuffer params(graph.context(), block);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
shader,
global_size,
local_size,
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Staging.cpp
@@ -48,6 +48,7 @@ void add_staging_to_tensor_node(
graph.context(), create_staging_params(t_out));

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
shader,
global_size,
local_size,
@@ -90,6 +91,7 @@
}

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
shader,
global_size,
local_size,
@@ -112,7 +114,7 @@ ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
api::UniformParamsBuffer params(graph.context(), sp);

graph.prepack_nodes().emplace_back(new PrepackNode(
shader, global_size, local_size, vref, v, std::move(params)));
graph, shader, global_size, local_size, vref, v, std::move(params)));

return v;
}