Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polish UX and display t/s #2286

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,64 +10,35 @@

import android.app.Activity;
import android.app.AlertDialog;
import android.content.Context;
import android.os.Bundle;
import android.widget.Button;
import android.widget.EditText;
import android.widget.TextView;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import android.widget.ImageButton;
import android.widget.ListView;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

public class MainActivity extends Activity implements Runnable, LlamaCallback {
private EditText mEditTextMessage;
private TextView mTextViewChat;
private Button mSendButton;
private Button mStopButton;
private Button mModelButton;
private ImageButton mModelButton;
private ListView mMessagesView;
private MessageAdapter mMessageAdapter;
private LlamaModule mModule = null;
private String mResult = null;
private Message mResultMessage = null;

private static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}

try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
private int mNumTokens = 0;
private long mRunStartTime = 0;

@Override
public void onResult(String result) {
System.out.println("onResult: " + result);
mResult = result;
mResultMessage.appendText(result);
mNumTokens++;
run();
}

private void setModel(String modelPath, String tokenizerPath) {
try {
String model = MainActivity.assetFilePath(getApplicationContext(), modelPath);
String tokenizer = MainActivity.assetFilePath(getApplicationContext(), tokenizerPath);
mModule = new LlamaModule(model, tokenizer, 0.8f);
} catch (IOException e) {
finish();
}
}

private void setLocalModel(String modelPath, String tokenizerPath) {
mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f);
}
Expand All @@ -82,14 +53,13 @@ private void modelDialog() {
public void onClick(android.content.DialogInterface dialog, int item) {
switch (item) {
case 0:
setModel("stories110M.pte", "tokenizer.bin");
setLocalModel("/data/local/tmp/stories110M.pte", "/data/local/tmp/tokenizer.bin");
break;
case 1:
setLocalModel("/data/local/tmp/language.pte", "/data/local/tmp/language.bin");
break;
}
mEditTextMessage.setText("");
mTextViewChat.setText("");
dialog.dismiss();
}
});
Expand All @@ -103,21 +73,41 @@ protected void onCreate(Bundle savedInstanceState) {
setContentView(R.layout.activity_main);

mEditTextMessage = findViewById(R.id.editTextMessage);
mTextViewChat = findViewById(R.id.textViewChat);
mSendButton = findViewById(R.id.sendButton);
mStopButton = findViewById(R.id.stopButton);
mModelButton = findViewById(R.id.modelButton);

mMessagesView = findViewById(R.id.messages_view);
mMessageAdapter = new MessageAdapter(this, R.layout.sent_message);
mMessagesView.setAdapter(mMessageAdapter);
mSendButton.setOnClickListener(
view -> {
String prompt = mEditTextMessage.getText().toString();
mTextViewChat.append(prompt);
mMessageAdapter.add(new Message(prompt, true));
mMessageAdapter.notifyDataSetChanged();
mEditTextMessage.setText("");
mResultMessage = new Message("", false);
mMessageAdapter.add(mResultMessage);
Runnable runnable =
new Runnable() {
@Override
public void run() {
runOnUiThread(
new Runnable() {
@Override
public void run() {
onModelRunStarted();
}
});

mModule.generate(prompt, MainActivity.this);

runOnUiThread(
new Runnable() {
@Override
public void run() {
onModelRunStopped();
}
});
}
};
new Thread(runnable).start();
Expand All @@ -131,10 +121,31 @@ public void run() {
mModelButton.setOnClickListener(
view -> {
mModule.stop();
mMessageAdapter.clear();
mMessageAdapter.notifyDataSetChanged();
modelDialog();
});

setModel("stories110M.pte", "tokenizer.bin");
setLocalModel("/data/local/tmp/stories110M.pte", "/data/local/tmp/tokenizer.bin");
onModelRunStopped();
}

private void onModelRunStarted() {
mSendButton.setEnabled(false);
mStopButton.setEnabled(true);
mRunStartTime = System.currentTimeMillis();
}

private void onModelRunStopped() {
long runDuration = System.currentTimeMillis() - mRunStartTime;
if (mResultMessage != null) {
mResultMessage.setTokensPerSecond(1.0f * mNumTokens / (runDuration / 1000.0f));
}
mSendButton.setEnabled(true);
mStopButton.setEnabled(false);
mNumTokens = 0;
mRunStartTime = 0;
mMessageAdapter.notifyDataSetChanged();
}

@Override
Expand All @@ -143,7 +154,7 @@ public void run() {
new Runnable() {
@Override
public void run() {
mTextViewChat.append(mResult);
mMessageAdapter.notifyDataSetChanged();
}
});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.example.executorchllamademo;

public class Message {
private String text;
private boolean isSent;
private float tokensPerSecond;

public Message(String text, boolean isSent) {
this.text = text;
this.isSent = isSent;
}

public String getText() {
return text;
}

public void appendText(String text) {
this.text += text;
}

public boolean getIsSent() {
return isSent;
}

public void setTokensPerSecond(float tokensPerSecond) {
this.tokensPerSecond = tokensPerSecond;
}

public float getTokensPerSecond() {
return tokensPerSecond;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.example.executorchllamademo;

import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.ArrayAdapter;
import android.widget.TextView;

public class MessageAdapter extends ArrayAdapter<Message> {
public MessageAdapter(android.content.Context context, int resource) {
super(context, resource);
}

@Override
public View getView(int position, View convertView, ViewGroup parent) {
Message currentMessage = getItem(position);
int layoutIdForListItem =
currentMessage.getIsSent() ? R.layout.sent_message : R.layout.received_message;
View listItemView =
LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false);
TextView messageTextView = listItemView.findViewById(R.id.message_text);
messageTextView.setText(currentMessage.getText());

if (currentMessage.getTokensPerSecond() > 0) {
TextView tokensView = listItemView.findViewById(R.id.tokens_per_second);
tokensView.setText("" + currentMessage.getTokensPerSecond() + " t/s");
}

return listItemView;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<shape xmlns:android="http://schemas.android.com/apk/res/android"
android:shape="rectangle">
<solid android:color="#fff" />
<corners android:radius="10dp" />
</shape>
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<shape xmlns:android="http://schemas.android.com/apk/res/android"
android:shape="rectangle">
<solid android:color="@color/colorPrimary" />
<corners android:radius="10dp" />
</shape>
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<vector android:height="24dp" android:tint="#000000"
android:viewportHeight="24" android:viewportWidth="24"
android:width="24dp" xmlns:android="http://schemas.android.com/apk/res/android">
<path android:fillColor="@android:color/white" android:pathData="M6,10c-1.1,0 -2,0.9 -2,2s0.9,2 2,2 2,-0.9 2,-2 -0.9,-2 -2,-2zM18,10c-1.1,0 -2,0.9 -2,2s0.9,2 2,2 2,-0.9 2,-2 -0.9,-2 -2,-2zM12,10c-1.1,0 -2,0.9 -2,2s0.9,2 2,2 2,-0.9 2,-2 -0.9,-2 -2,-2z"/>
</vector>
Original file line number Diff line number Diff line change
@@ -1,28 +1,40 @@
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
android:clipToPadding="false"
android:focusableInTouchMode="true"
tools:context=".MainActivity">
<EditText
android:id="@+id/editTextMessage"

<ListView
android:layout_width="match_parent"
android:id="@+id/messages_view"
android:layout_weight="2"
android:divider="#fff"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:textSize="20sp"
android:hint="Type a prompt" />
<LinearLayout
/>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:orientation="horizontal"
android:gravity="right"
tools:ignore="RtlHardcoded">
<Button
android:background="#fff"
android:orientation="horizontal">
<ImageButton
android:id="@+id/modelButton"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Model..." />
android:src="@drawable/three_dots" />
<EditText
android:id="@+id/editTextMessage"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_weight="2"
android:ems="10"
android:hint="Prompt"
android:inputType="text"
android:paddingHorizontal="10dp"
android:text="" />
<Button
android:id="@+id/stopButton"
android:layout_width="wrap_content"
Expand All @@ -34,11 +46,4 @@
android:layout_height="wrap_content"
android:text="Generate" />
</LinearLayout>
<TextView
android:id="@+id/textViewChat"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_above="@id/editTextMessage"
android:textSize="24sp"
android:scrollbars="vertical" />
</RelativeLayout>
</LinearLayout>
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:paddingVertical="10dp"
android:paddingLeft="15dp"
android:paddingRight="60dp"
android:clipToPadding="false">

<TextView
android:id="@+id/name"
android:layout_marginLeft="15dp"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:paddingBottom="4dp"
android:text="LLaMA"/>

<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:id="@+id/message_text"
android:layout_below="@+id/name"
android:layout_alignLeft="@+id/name"
android:background="@drawable/received_message"
android:paddingVertical="12dp"
android:paddingHorizontal="16dp"
android:elevation="2dp"
android:textSize="18dp"
android:text="Generated text"
/>


<TextView
android:id="@+id/tokens_per_second"
android:layout_marginLeft="15dp"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_below="@+id/message_text"
android:paddingBottom="4dp"
android:text=""/>

</RelativeLayout>
Loading
Loading