[ET-VK][ez] Ensure descriptor set pools don't run out of memory
## Context

While testing a toy model with a large number of operators, I ran into an issue on my local Pixel 6 Android device where the descriptor pool was running out of memory. This changeset implements a simple fix to ensure that descriptor pools do not run into this issue.

A longer-term solution is to implement layout-specific descriptor pools, but that is much more technically complex, so go with this simple fix for now.

## Problem Details

#2285 made it so that `ComputeGraph` could tally up the total number of descriptors needed and size the descriptor pools appropriately, but it seems that this is not compatible with certain Vulkan drivers.

In the toy model, 1000 binary operators were added. Tallying the descriptors required for the graph yields the following counts:

```
descriptorPoolMaxSets: 1255
descriptorUniformBufferCount: 5013
descriptorStorageBufferCount: 4
descriptorCombinedSamplerCount: 2504
descriptorStorageImageCount: 1254
```

These counts appear to be correct; however, the descriptor pool still runs out of memory due to an insufficient `descriptorStorageBufferCount`. The `descriptorStorageBufferCount` needs to be set surprisingly high (approximately 1000) before the descriptor pool stops running out of memory. I'm not sure exactly what causes this behaviour, but it may be due to implementation details of the driver.

## Solution

Ensuring that all descriptor counts are greater than or equal to the maximum number of descriptor sets seems to work; implement this as a temporary solution.

Differential Revision: [D54853788](https://our.internmc.facebook.com/intern/diff/D54853788/)

ghstack-source-id: 218502788
Pull Request resolved: #2398
SS-JIA committed Mar 13, 2024
1 parent 9ff2c0e commit 66310e7
Showing 2 changed files with 62 additions and 7 deletions.
11 changes: 6 additions & 5 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
```diff
@@ -191,12 +191,13 @@ void ComputeGraph::prepare() {
       prepack_descriptor_counts_.field) * \
       config_.descriptorPoolSafetyFactor))
 
+  uint32_t max_sets = MERGE_FIELD(descriptorPoolMaxSets);
   api::DescriptorPoolConfig config{
-      MERGE_FIELD(descriptorPoolMaxSets),
-      MERGE_FIELD(descriptorUniformBufferCount),
-      MERGE_FIELD(descriptorStorageBufferCount),
-      MERGE_FIELD(descriptorCombinedSamplerCount),
-      MERGE_FIELD(descriptorStorageImageCount),
+      max_sets,
+      std::max(MERGE_FIELD(descriptorUniformBufferCount), max_sets),
+      std::max(MERGE_FIELD(descriptorStorageBufferCount), max_sets),
+      std::max(MERGE_FIELD(descriptorCombinedSamplerCount), max_sets),
+      std::max(MERGE_FIELD(descriptorStorageImageCount), max_sets),
       1u,
   };
```
58 changes: 56 additions & 2 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
```diff
@@ -671,10 +671,64 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
     EXTRACT_TENSOR(out);
 
     // Sanity check that the values are correct
-    int i = 0;
     for (const auto& val : data_out) {
       ASSERT_TRUE(val == val_out);
-      ++i;
     }
   }
 }
 
+TEST(VulkanComputeGraphTest, test_large_graph) {
+  GraphConfig config;
+  ComputeGraph graph(config);
+
+  int64_t input_w = 256;
+  int64_t input_h = 256;
+  int64_t input_c = 8;
+
+  std::vector<int64_t> size_big = {input_c, input_h, input_w};
+  std::vector<int64_t> size_small = {input_c, input_h, 1};
+
+  // Build graph
+
+  IOValueRef a = graph.add_input_tensor(size_big, api::kFloat, 2);
+  IOValueRef b = graph.add_input_tensor(size_small, api::kFloat, 4);
+
+  ValueRef c = graph.add_tensor(size_big, api::kFloat, 6);
+
+  auto addFn = VK_GET_OP_FN("aten.add.Tensor");
+  addFn(graph, {a.value, b.value, kDummyValueRef, c});
+
+  int n = 100;
+
+  for (int i = 0; i < n; i++) {
+    addFn(graph, {c, b.value, kDummyValueRef, a.value});
+
+    addFn(graph, {a.value, b.value, kDummyValueRef, c});
+  }
+
+  IOValueRef out = {};
+  out.value = c;
+  out.staging = graph.set_output_tensor(out.value);
+
+  graph.prepare();
+  graph.encode_execute();
+
+  for (int i = 0; i < 10; i++) {
+    float val_a = 1.0f;
+    float val_b = 2.0f;
+
+    float val_e = val_a + val_b * (2 * n + 1);
+
+    fill_vtensor(graph, a, val_a);
+    fill_vtensor(graph, b, val_b);
+
+    graph.execute();
+
+    EXTRACT_TENSOR(out);
+
+    for (const auto& val : data_out) {
+      EXPECT_TRUE(val == val_e);
+    }
+  }
+}
```
