Skip to content

Commit

Permalink
Polish UX and display t/s (#2286)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2286

Differential Revision: D54612043
  • Loading branch information
kirklandsign authored and facebook-github-bot committed Mar 7, 2024
1 parent cf929d0 commit 0a95d0c
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@ public class MainActivity extends Activity implements Runnable, LlamaCallback {
private LlamaModule mModule = null;
private Message mResultMessage = null;

private int mNumTokens = 0;
private long mRunStartTime = 0;

@Override
public void onResult(String result) {
System.out.println("onResult: " + result);
mResultMessage.appendText(result);
mNumTokens++;
run();
}

Expand Down Expand Up @@ -87,7 +91,23 @@ protected void onCreate(Bundle savedInstanceState) {
new Runnable() {
@Override
public void run() {
runOnUiThread(
new Runnable() {
@Override
public void run() {
onModelRunStarted();
}
});

mModule.generate(prompt, MainActivity.this);

runOnUiThread(
new Runnable() {
@Override
public void run() {
onModelRunStopped();
}
});
}
};
new Thread(runnable).start();
Expand All @@ -107,6 +127,25 @@ public void run() {
});

setLocalModel("/data/local/tmp/stories110M.pte", "/data/local/tmp/tokenizer.bin");
onModelRunStopped();
}

private void onModelRunStarted() {
mSendButton.setEnabled(false);
mStopButton.setEnabled(true);
mRunStartTime = System.currentTimeMillis();
}

private void onModelRunStopped() {
long runDuration = System.currentTimeMillis() - mRunStartTime;
if (mResultMessage != null) {
mResultMessage.setTokensPerSecond(1.0f * mNumTokens / (runDuration / 1000.0f));
}
mSendButton.setEnabled(true);
mStopButton.setEnabled(false);
mNumTokens = 0;
mRunStartTime = 0;
mMessageAdapter.notifyDataSetChanged();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
public class Message {
private String text;
private boolean isSent;
private float tokensPerSecond;

public Message(String text, boolean isSent) {
this.text = text;
Expand All @@ -28,4 +29,12 @@ public void appendText(String text) {
public boolean getIsSent() {
return isSent;
}

public void setTokensPerSecond(float tokensPerSecond) {
this.tokensPerSecond = tokensPerSecond;
}

public float getTokensPerSecond() {
return tokensPerSecond;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ public View getView(int position, View convertView, ViewGroup parent) {
LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false);
TextView messageTextView = listItemView.findViewById(R.id.message_text);
messageTextView.setText(currentMessage.getText());

if (currentMessage.getTokensPerSecond() > 0) {
TextView tokensView = listItemView.findViewById(R.id.tokens_per_second);
tokensView.setText("" + currentMessage.getTokensPerSecond() + " t/s");
}

return listItemView;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,15 @@
android:textSize="18dp"
android:text="Generated text"
/>


<TextView
android:id="@+id/tokens_per_second"
android:layout_marginLeft="15dp"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_below="@+id/message_text"
android:paddingBottom="4dp"
android:text=""/>

</RelativeLayout>

0 comments on commit 0a95d0c

Please sign in to comment.