Aten _To_Copy #6055

Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/supported_ops.py
@@ -70,6 +70,7 @@ def __contains__(self, op):
exir_ops.edge.aten.sin.default,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten._to_copy.default,
# Matrix Multiplication
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.mm.default,
7 changes: 7 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -144,6 +144,10 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:

return False

def is_valid_to_copy(self, node: torch.fx.Node) -> bool:
# Lower only floating-point dtype conversions. The target dtype is read
# positionally (args[1]); nodes without it conservatively fail the check.
return len(node.args) > 1 and node.args[1] in (torch.float32, torch.float16)

def is_node_supported(
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
@@ -172,6 +176,9 @@ def _is_node_supported(

features = VulkanSupportedOperators._ops[target]

if target == exir_ops.edge.aten._to_copy.default and not self.is_valid_to_copy(node):
return False

if self.require_dynamic_shapes and not features.supports_dynamic_shape:
return False

48 changes: 48 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp
@@ -0,0 +1,48 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/BlitNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
#include <set>

namespace vkcompute {

void resize_to_copy_op_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
(void)extra_args;
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
vTensorPtr self = graph->get_tensor(args[1].refs[0]);

out->virtual_resize(self->sizes());
}

void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
static std::set<vkapi::ScalarType> supported_types = {
vkapi::ScalarType::Float, vkapi::ScalarType::Half};

VK_CHECK_COND(
supported_types.find(graph.dtype_of(in)) != supported_types.end() &&
supported_types.find(graph.dtype_of(out)) != supported_types.end(),
"Unsupported dtype for to_copy, only Float and Half are currently supported, recieved ", vkapi::to_string(graph.dtype_of(in)), " <-> ", vkapi::to_string(graph.dtype_of(out)));

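// The copy (and the Float <-> Half conversion) is performed as an image
// blit; BlitNode is presumably backed by vkCmdBlitImage, which converts
// formats in flight.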
graph.execute_nodes().emplace_back(
new BlitNode(graph, prepack_if_tensor_ref(graph, in), out));
}

void to_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
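// aten._to_copy.default schema: (self, dtype, layout, device, pin_memory,
// non_blocking, memory_format); the graph builder appends the output last,
// hence input at args[0] and output at args[7].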
return add_to_copy_node(graph, args[0], args[7]);
}

REGISTER_OPERATORS {
VK_REGISTER_OP(aten._to_copy.default, to_copy);
}
} // namespace vkcompute
108 changes: 108 additions & 0 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -8,6 +8,7 @@

#include <gtest/gtest.h>

#include <bitset>
#include <utility>
#include <vector>

@@ -3251,3 +3252,110 @@ TEST(VulkanComputeGraphOpsTest, test_transpose_with_mm) {
test_transpose_view_mm(2, 7, 17, 5, storage_type);
}
}

void test_to_copy() {
GraphConfig config;
config.set_storage_type_override(utils::kTexture3D);
ComputeGraph graph(config);
int M = 8;
int N = 8;
int K = 8;
// Build graph
IOValueRef in = graph.add_input_tensor(
{1, M, N, K},
vkapi::kFloat,
utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);

std::vector<float> data_in =
create_random_float_buffer(M * N * K, -1024, 1024);
graph.copy_into_staging(in.staging, data_in.data(), data_in.size());

IOValueRef out;
out.value = graph.add_tensor(
{1, M, N, K},
vkapi::kHalf,
utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);

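// Note: the dtype argument below is passed as None; the conversion is driven
// by the input/output tensor dtypes (kFloat -> kHalf) rather than the schema
// argument.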
auto op = VK_GET_OP_FN("aten._to_copy.default");
op(graph,
{in.value,
graph.add_none(),
graph.add_none(),
graph.add_none(),
graph.add_none(),
graph.add_none(),
graph.add_none(),
out.value});

out.staging = graph.set_output_tensor(out.value);

graph.prepare();
graph.encode_prepack();
graph.prepack();
graph.encode_execute();
graph.propagate_resize();
graph.execute();

std::vector<torch::executor::Half> output_data(graph.numel_of(out.value));
graph.copy_from_staging(out.staging, output_data.data(), output_data.size());

EXPECT_EQ(data_in.size(), output_data.size());

float mse_ex = 0.0f;
float mse_vk = 0.0f;

// check results
for (size_t i = 0; i < output_data.size(); ++i) {
float input = data_in[i];
torch::executor::Half expected_output =
static_cast<torch::executor::Half>(input);
uint16_t* expected_bits = reinterpret_cast<uint16_t*>(&expected_output);
torch::executor::Half output = output_data[i];
uint16_t* output_bits = reinterpret_cast<uint16_t*>(&output);

std::string msg;
msg.reserve(64);
msg = "input = " + std::to_string(input) + "(0b" +
std::bitset<32>(*reinterpret_cast<uint32_t*>(&input)).to_string() +
"), expected output = " + std::to_string(expected_output) + "(0b" +
std::bitset<16>(*expected_bits).to_string() +
"), recieved output = " + std::to_string(output) + "(0b" +
std::bitset<16>(*output_bits).to_string() + ")";

std::cout << msg << std::endl;

// Note: Torch executor Half rounds to nearest when converting to fp16,
// whereas most driver implementations of Vulkan's OpFConvert simply
// truncate the extra bits for performance (rounding introduces a
// conditional).
// Example:
// INPUT F32 = 25.248 (sign{0b0}, exp{0b10000011},
// mantissa{0b10010011111101111100111}),
// TORCH HALF OUTPUT F16 = 25.25 (sign{0b0}, exp{0b10011},
// mantissa{0b1001010000}),
// VULKAN OUTPUT F16 = 25.2344 (sign{0b0}, exp{0b10011},
// mantissa{0b1001001111})
// Note:
// The Vulkan mantissa exactly matches the first 10
// bits of the input's 23-bit mantissa. But since the 11th bit is 1, the
// Torch Half output is rounded up (essentially adding 1):
// Vulkan mantissa{0b1001001111} + 1 = Torch Half mantissa{0b1001010000}

EXPECT_TRUE(
(*output_bits == *expected_bits) ||
/*rounding error*/ ((*output_bits + 1u) == *expected_bits));
mse_ex += std::pow(expected_output - input, 2);
mse_vk += std::pow(output - input, 2);
}

mse_ex /= output_data.size();
mse_vk /= output_data.size();
std::cout << "========================================================="
<< std::endl;
std::cout << "mse_ex = " << mse_ex << ", mse_vk = " << mse_vk << std::endl;
}

TEST(VulkanComputeGraphOpsTest, test_to_copy) {
if (context()->adapter_ptr()->has_16bit_storage()) {
test_to_copy();
}
}
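
For readers unfamiliar with the truncation behavior the test tolerates, here is a minimal standalone sketch (not part of this PR) of a truncating fp32 to fp16 conversion for normal values, matching the OpFConvert behavior described in the test comment; round-to-nearest would add one ulp whenever the dropped bits are at or above half a ulp. The helper name fp32_to_fp16_truncate is illustrative, and subnormals, infinities, and NaN are deliberately not handled.

#include <bitset>
#include <cstdint>
#include <cstring>
#include <iostream>

uint16_t fp32_to_fp16_truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint16_t sign = (bits >> 16) & 0x8000u; // move sign bit to fp16 position
  // Re-bias the exponent from fp32 (127) to fp16 (15); normal range only.
  uint16_t exp = static_cast<uint16_t>(((bits >> 23) & 0xFFu) - 127 + 15);
  uint16_t mantissa = (bits >> 13) & 0x3FFu; // keep top 10 of 23 mantissa bits
  return sign | (exp << 10) | mantissa;
}

int main() {
  float input = 25.248f;
  uint16_t half_bits = fp32_to_fp16_truncate(input);
  // Prints 0b0100111001001111: exp{0b10011}, mantissa{0b1001001111}, i.e. the
  // Vulkan result from the comment above; round-to-nearest would add 1.
  std::cout << "0b" << std::bitset<16>(half_bits) << std::endl;
  return 0;
}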