From ca889fb59b06ec8bf07d4c0c56fed2b59d0d0a37 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Tue, 10 Sep 2024 17:15:57 -0700 Subject: [PATCH] Minibench use model_dir instead (#5250) Summary: We specify a model dir, not model path. It's easier to update test spec Pull Request resolved: https://github.com/pytorch/executorch/pull/5250 Reviewed By: huydhn Differential Revision: D62473641 Pulled By: kirklandsign fbshipit-source-id: 40864831de9960fe29b101683ef7182e2f56fe7b --- .../org/pytorch/minibench/BenchmarkActivity.java | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java index e9599dd351..a79f668f80 100644 --- a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java @@ -11,8 +11,10 @@ import android.app.Activity; import android.content.Intent; import android.os.Bundle; +import java.io.File; import java.io.FileWriter; import java.io.IOException; +import java.util.Arrays; import org.pytorch.executorch.Module; public class BenchmarkActivity extends Activity { @@ -20,13 +22,19 @@ public class BenchmarkActivity extends Activity { protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); Intent intent = getIntent(); - String modelPath = intent.getStringExtra("model_path"); + File modelDir = new File(intent.getStringExtra("model_dir")); + File model = + Arrays.stream(modelDir.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .findFirst() + .get(); + int numIter = intent.getIntExtra("num_iter", 10); // TODO: Format the string with a parsable format StringBuilder resultText = new StringBuilder(); - Module module = Module.load(modelPath); + Module module = Module.load(model.getPath()); for (int i = 0; i < numIter; i++) { long start = System.currentTimeMillis(); module.forward();