From 45d20f736e31fdc14ade3f917794fa5187dc1cc6 Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Wed, 29 Nov 2023 16:30:54 -0800 Subject: [PATCH] Add support for PyTorch style printing of output tensors --- .../mps/executor_runner/mps_executor_runner.mm | 15 ++++++--------- examples/apple/mps/executor_runner/targets.bzl | 1 + 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/apple/mps/executor_runner/mps_executor_runner.mm b/examples/apple/mps/executor_runner/mps_executor_runner.mm index 379bd42282..ef37beab0e 100644 --- a/examples/apple/mps/executor_runner/mps_executor_runner.mm +++ b/examples/apple/mps/executor_runner/mps_executor_runner.mm @@ -17,6 +17,7 @@ #import #include +#include #include @@ -31,6 +32,7 @@ #include #include #include +#include #include using namespace std::chrono; @@ -440,15 +442,10 @@ MemoryManager memory_manager( std::vector outputs(method->outputs_size()); status = method->get_outputs(outputs.data(), outputs.size()); ET_CHECK(status == Error::Ok); - for (EValue& output : outputs) { - // TODO(T159700776): This assumes that all outputs are fp32 tensors. Add - // support for other EValues and Tensor dtypes, and print tensors in a more - // readable way. - auto output_tensor = output.toTensor(); - auto data_output = output_tensor.const_data_ptr(); - for (size_t j = 0; j < output_tensor.numel(); ++j) { - ET_LOG(Info, "%f", data_output[j]); - } + // Print the first and last 100 elements of long lists of scalars. + std::cout << torch::executor::util::evalue_edge_items(100); + for (int i = 0; i < outputs.size(); ++i) { + std::cout << "Output " << i << ": " << outputs[i] << std::endl; } // Dump the profiling data to the specified file. diff --git a/examples/apple/mps/executor_runner/targets.bzl b/examples/apple/mps/executor_runner/targets.bzl index 21ee3373f9..48c4dc3fcf 100644 --- a/examples/apple/mps/executor_runner/targets.bzl +++ b/examples/apple/mps/executor_runner/targets.bzl @@ -15,6 +15,7 @@ def define_common_targets(): deps = [ "//executorch/backends/apple/mps/runtime:MPSBackend", "//executorch/runtime/executor:program", + "//executorch/extension/evalue_util:print_evalue", "//executorch/extension/data_loader:file_data_loader", "//executorch/kernels/portable:generated_lib_all_ops", "//executorch/extension/data_loader:file_data_loader",