diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 58855c03fa..a4d93ef55e 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -202,6 +202,19 @@ def forward(self, x, y, w): self.lower_module_and_test_output(add_module, sample_inputs) + def test_vulkan_backend_zero_dim_tensor(self): + class ZeroDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.zero = torch.full([], 1.3, dtype=torch.float32) + + def forward(self, x): + return x + self.zero + + internal_data_module = ZeroDimModule() + sample_inputs = (torch.rand(size=(2, 3), dtype=torch.float32),) + self.lower_module_and_test_output(internal_data_module, sample_inputs) + def test_vulkan_backend_internal_data(self): class InternalDataModule(torch.nn.Module): def __init__(self): diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 939367851e..dd926d3470 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -627,6 +627,51 @@ TEST(VulkanComputeGraphTest, test_values_string) { EXPECT_TRUE(stored == "hello, world"); } +TEST(VulkanComputeGraphTest, test_zero_dim_tensor) { + GraphConfig config; + ComputeGraph graph(config); + + std::vector size_big = {7, 3, 5}; + std::vector size_small = {}; + + // Build graph + + IOValueRef a = graph.add_input_tensor(size_big, api::kFloat); + IOValueRef b = graph.add_input_tensor(size_small, api::kFloat); + + IOValueRef out = {}; + + out.value = graph.add_tensor(size_big, api::kFloat); + + auto addFn = VK_GET_OP_FN("aten.add.Tensor"); + addFn(graph, {a.value, b.value, kDummyValueRef, out.value}); + + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_execute(); + + // Run graph + + for (float i = 5.0f; i < 30.0f; i += 10.0f) { + float val_a = i + 2.0f; + float val_b = i + 1.5f; + float val_c = val_a + val_b; + + fill_vtensor(graph, a, val_a); + fill_vtensor(graph, b, val_b); + + graph.execute(); + + EXTRACT_TENSOR(out); + + // Sanity check that the values are correct + for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) { + CHECK_VALUE(data_out, i, val_c); + } + } +} + TEST(VulkanComputeGraphTest, test_simple_graph) { GraphConfig config; ComputeGraph graph(config);