Skip to content

Commit

Permalink
Add support for PyTorch style printing of output tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisVieriu97 committed Nov 30, 2023
1 parent 9853d93 commit 45d20f7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
15 changes: 6 additions & 9 deletions examples/apple/mps/executor_runner/mps_executor_runner.mm
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

#include <memory>
#include <iostream>

#include <gflags/gflags.h>

Expand All @@ -31,6 +32,7 @@
#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/extension/evalue_util/print_evalue.h>

#include <chrono>
using namespace std::chrono;
Expand Down Expand Up @@ -440,15 +442,10 @@ MemoryManager memory_manager(
std::vector<EValue> outputs(method->outputs_size());
status = method->get_outputs(outputs.data(), outputs.size());
ET_CHECK(status == Error::Ok);
for (EValue& output : outputs) {
// TODO(T159700776): This assumes that all outputs are fp32 tensors. Add
// support for other EValues and Tensor dtypes, and print tensors in a more
// readable way.
auto output_tensor = output.toTensor();
auto data_output = output_tensor.const_data_ptr<float>();
for (size_t j = 0; j < output_tensor.numel(); ++j) {
ET_LOG(Info, "%f", data_output[j]);
}
// Print the first and last 100 elements of long lists of scalars.
std::cout << torch::executor::util::evalue_edge_items(100);
for (int i = 0; i < outputs.size(); ++i) {
std::cout << "Output " << i << ": " << outputs[i] << std::endl;
}

// Dump the profiling data to the specified file.
Expand Down
1 change: 1 addition & 0 deletions examples/apple/mps/executor_runner/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def define_common_targets():
deps = [
"//executorch/backends/apple/mps/runtime:MPSBackend",
"//executorch/runtime/executor:program",
"//executorch/extension/evalue_util:print_evalue",
"//executorch/extension/data_loader:file_data_loader",
"//executorch/kernels/portable:generated_lib_all_ops",
"//executorch/extension/data_loader:file_data_loader",
Expand Down

0 comments on commit 45d20f7

Please sign in to comment.