diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java index c24c367860..a9606f2b46 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java @@ -10,64 +10,35 @@ import android.app.Activity; import android.app.AlertDialog; -import android.content.Context; import android.os.Bundle; import android.widget.Button; import android.widget.EditText; -import android.widget.TextView; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; +import android.widget.ImageButton; +import android.widget.ListView; import org.pytorch.executorch.LlamaCallback; import org.pytorch.executorch.LlamaModule; public class MainActivity extends Activity implements Runnable, LlamaCallback { private EditText mEditTextMessage; - private TextView mTextViewChat; private Button mSendButton; private Button mStopButton; - private Button mModelButton; + private ImageButton mModelButton; + private ListView mMessagesView; + private MessageAdapter mMessageAdapter; private LlamaModule mModule = null; - private String mResult = null; + private Message mResultMessage = null; - private static String assetFilePath(Context context, String assetName) throws IOException { - File file = new File(context.getFilesDir(), assetName); - if (file.exists() && file.length() > 0) { - return file.getAbsolutePath(); - } - - try (InputStream is = context.getAssets().open(assetName)) { - try (OutputStream os = new FileOutputStream(file)) { - byte[] buffer = new byte[4 * 1024]; - int read; - while ((read = is.read(buffer)) != -1) { - os.write(buffer, 0, read); - } - os.flush(); - } - return file.getAbsolutePath(); - } - } + private int mNumTokens = 0; + private long mRunStartTime = 0; @Override public void onResult(String result) { System.out.println("onResult: " + result); - mResult = result; + mResultMessage.appendText(result); + mNumTokens++; run(); } - private void setModel(String modelPath, String tokenizerPath) { - try { - String model = MainActivity.assetFilePath(getApplicationContext(), modelPath); - String tokenizer = MainActivity.assetFilePath(getApplicationContext(), tokenizerPath); - mModule = new LlamaModule(model, tokenizer, 0.8f); - } catch (IOException e) { - finish(); - } - } - private void setLocalModel(String modelPath, String tokenizerPath) { mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f); } @@ -82,14 +53,13 @@ private void modelDialog() { public void onClick(android.content.DialogInterface dialog, int item) { switch (item) { case 0: - setModel("stories110M.pte", "tokenizer.bin"); + setLocalModel("/data/local/tmp/stories110M.pte", "/data/local/tmp/tokenizer.bin"); break; case 1: setLocalModel("/data/local/tmp/language.pte", "/data/local/tmp/language.bin"); break; } mEditTextMessage.setText(""); - mTextViewChat.setText(""); dialog.dismiss(); } }); @@ -103,21 +73,41 @@ protected void onCreate(Bundle savedInstanceState) { setContentView(R.layout.activity_main); mEditTextMessage = findViewById(R.id.editTextMessage); - mTextViewChat = findViewById(R.id.textViewChat); mSendButton = findViewById(R.id.sendButton); mStopButton = findViewById(R.id.stopButton); mModelButton = findViewById(R.id.modelButton); - + mMessagesView = findViewById(R.id.messages_view); + mMessageAdapter = new MessageAdapter(this, R.layout.sent_message); + mMessagesView.setAdapter(mMessageAdapter); mSendButton.setOnClickListener( view -> { String prompt = mEditTextMessage.getText().toString(); - mTextViewChat.append(prompt); + mMessageAdapter.add(new Message(prompt, true)); + mMessageAdapter.notifyDataSetChanged(); mEditTextMessage.setText(""); + mResultMessage = new Message("", false); + mMessageAdapter.add(mResultMessage); Runnable runnable = new Runnable() { @Override public void run() { + runOnUiThread( + new Runnable() { + @Override + public void run() { + onModelRunStarted(); + } + }); + mModule.generate(prompt, MainActivity.this); + + runOnUiThread( + new Runnable() { + @Override + public void run() { + onModelRunStopped(); + } + }); } }; new Thread(runnable).start(); @@ -131,10 +121,31 @@ public void run() { mModelButton.setOnClickListener( view -> { mModule.stop(); + mMessageAdapter.clear(); + mMessageAdapter.notifyDataSetChanged(); modelDialog(); }); - setModel("stories110M.pte", "tokenizer.bin"); + setLocalModel("/data/local/tmp/stories110M.pte", "/data/local/tmp/tokenizer.bin"); + onModelRunStopped(); + } + + private void onModelRunStarted() { + mSendButton.setEnabled(false); + mStopButton.setEnabled(true); + mRunStartTime = System.currentTimeMillis(); + } + + private void onModelRunStopped() { + long runDuration = System.currentTimeMillis() - mRunStartTime; + if (mResultMessage != null) { + mResultMessage.setTokensPerSecond(1.0f * mNumTokens / (runDuration / 1000.0f)); + } + mSendButton.setEnabled(true); + mStopButton.setEnabled(false); + mNumTokens = 0; + mRunStartTime = 0; + mMessageAdapter.notifyDataSetChanged(); } @Override @@ -143,7 +154,7 @@ public void run() { new Runnable() { @Override public void run() { - mTextViewChat.append(mResult); + mMessageAdapter.notifyDataSetChanged(); } }); } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/Message.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/Message.java new file mode 100644 index 0000000000..81b77b1aba --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/Message.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +public class Message { + private String text; + private boolean isSent; + private float tokensPerSecond; + + public Message(String text, boolean isSent) { + this.text = text; + this.isSent = isSent; + } + + public String getText() { + return text; + } + + public void appendText(String text) { + this.text += text; + } + + public boolean getIsSent() { + return isSent; + } + + public void setTokensPerSecond(float tokensPerSecond) { + this.tokensPerSecond = tokensPerSecond; + } + + public float getTokensPerSecond() { + return tokensPerSecond; + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/MessageAdapter.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/MessageAdapter.java new file mode 100644 index 0000000000..188bd7ed4b --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/MessageAdapter.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +import android.view.LayoutInflater; +import android.view.View; +import android.view.ViewGroup; +import android.widget.ArrayAdapter; +import android.widget.TextView; + +public class MessageAdapter extends ArrayAdapter { + public MessageAdapter(android.content.Context context, int resource) { + super(context, resource); + } + + @Override + public View getView(int position, View convertView, ViewGroup parent) { + Message currentMessage = getItem(position); + int layoutIdForListItem = + currentMessage.getIsSent() ? R.layout.sent_message : R.layout.received_message; + View listItemView = + LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false); + TextView messageTextView = listItemView.findViewById(R.id.message_text); + messageTextView.setText(currentMessage.getText()); + + if (currentMessage.getTokensPerSecond() > 0) { + TextView tokensView = listItemView.findViewById(R.id.tokens_per_second); + tokensView.setText("" + currentMessage.getTokensPerSecond() + " t/s"); + } + + return listItemView; + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/received_message.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/received_message.xml new file mode 100644 index 0000000000..ea2d1bbfa1 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/received_message.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml new file mode 100644 index 0000000000..e8d13ca4e1 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml new file mode 100644 index 0000000000..afbe22da80 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml @@ -0,0 +1,5 @@ + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml index f769578d33..976a64b6ce 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml @@ -1,28 +1,40 @@ - - - + -