Skip to content

Commit

Permalink
Polish UX and display t/s (pytorch#2286)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#2286

Reviewed By: shoumikhin

Differential Revision: D54612043
  • Loading branch information
kirklandsign authored and facebook-github-bot committed Mar 7, 2024
1 parent 59bd01a commit 0bb842e
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,20 @@
public class MainActivity extends Activity implements Runnable, LlamaCallback {
private EditText mEditTextMessage;
private Button mSendButton;
private Button mStopButton;
private ImageButton mModelButton;
private ListView mMessagesView;
private MessageAdapter mMessageAdapter;
private LlamaModule mModule = null;
private Message mResultMessage = null;

private int mNumTokens = 0;
private long mRunStartTime = 0;

@Override
public void onResult(String result) {
System.out.println("onResult: " + result);
mResultMessage.appendText(result);
mNumTokens++;
run();
}

Expand Down Expand Up @@ -70,11 +73,38 @@ protected void onCreate(Bundle savedInstanceState) {

mEditTextMessage = findViewById(R.id.editTextMessage);
mSendButton = findViewById(R.id.sendButton);
mStopButton = findViewById(R.id.stopButton);
mModelButton = findViewById(R.id.modelButton);
mMessagesView = findViewById(R.id.messages_view);
mMessageAdapter = new MessageAdapter(this, R.layout.sent_message);
mMessagesView.setAdapter(mMessageAdapter);
mModelButton.setOnClickListener(
view -> {
mModule.stop();
mMessageAdapter.clear();
mMessageAdapter.notifyDataSetChanged();
modelDialog();
});

setLocalModel("/data/local/tmp/stories110M.pte", "/data/local/tmp/tokenizer.bin");
onModelRunStopped();
}

private void onModelRunStarted() {
mSendButton.setText("Stop");
mSendButton.setOnClickListener(
view -> {
mModule.stop();
});

mRunStartTime = System.currentTimeMillis();
}

private void onModelRunStopped() {
long runDuration = System.currentTimeMillis() - mRunStartTime;
if (mResultMessage != null) {
mResultMessage.setTokensPerSecond(1.0f * mNumTokens / (runDuration / 1000.0f));
}
mSendButton.setText("Generate");
mSendButton.setOnClickListener(
view -> {
String prompt = mEditTextMessage.getText().toString();
Expand All @@ -87,26 +117,30 @@ protected void onCreate(Bundle savedInstanceState) {
new Runnable() {
@Override
public void run() {
runOnUiThread(
new Runnable() {
@Override
public void run() {
onModelRunStarted();
}
});

mModule.generate(prompt, MainActivity.this);

runOnUiThread(
new Runnable() {
@Override
public void run() {
onModelRunStopped();
}
});
}
};
new Thread(runnable).start();
});

mStopButton.setOnClickListener(
view -> {
mModule.stop();
});

mModelButton.setOnClickListener(
view -> {
mModule.stop();
mMessageAdapter.clear();
mMessageAdapter.notifyDataSetChanged();
modelDialog();
});

setLocalModel("/data/local/tmp/stories110M.pte", "/data/local/tmp/tokenizer.bin");
mNumTokens = 0;
mRunStartTime = 0;
mMessageAdapter.notifyDataSetChanged();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
public class Message {
private String text;
private boolean isSent;
private float tokensPerSecond;

public Message(String text, boolean isSent) {
this.text = text;
Expand All @@ -28,4 +29,12 @@ public void appendText(String text) {
public boolean getIsSent() {
return isSent;
}

public void setTokensPerSecond(float tokensPerSecond) {
this.tokensPerSecond = tokensPerSecond;
}

public float getTokensPerSecond() {
return tokensPerSecond;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ public View getView(int position, View convertView, ViewGroup parent) {
LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false);
TextView messageTextView = listItemView.findViewById(R.id.message_text);
messageTextView.setText(currentMessage.getText());

if (currentMessage.getTokensPerSecond() > 0) {
TextView tokensView = listItemView.findViewById(R.id.tokens_per_second);
tokensView.setText("" + currentMessage.getTokensPerSecond() + " t/s");
}

return listItemView;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@
android:inputType="text"
android:paddingHorizontal="10dp"
android:text="" />
<Button
android:id="@+id/stopButton"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Stop" />
<Button
android:id="@+id/sendButton"
android:layout_width="wrap_content"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,15 @@
android:textSize="18dp"
android:text="Generated text"
/>


<TextView
android:id="@+id/tokens_per_second"
android:layout_marginLeft="15dp"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_below="@+id/message_text"
android:paddingBottom="4dp"
android:text=""/>

</RelativeLayout>

0 comments on commit 0bb842e

Please sign in to comment.