Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama.swiftui: Fix a small bug #8268

Merged
merged 4 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ actor LlamaContext {
private var context: OpaquePointer
private var batch: llama_batch
private var tokens_list: [llama_token]
var latest_llama_token_is_eog_or_reach_len: Bool = false

/// This variable is used to store temporarily invalid cchars
private var temporary_invalid_cchars: [CChar]
Expand Down Expand Up @@ -160,6 +161,7 @@ actor LlamaContext {

if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
print("\n")
latest_llama_token_is_eog_or_reach_len = true
let new_token_str = String(cString: temporary_invalid_cchars + [0])
temporary_invalid_cchars.removeAll()
return new_token_str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class LlamaState: ObservableObject {
messageLog += "\(text)"

Task.detached {
while await llamaContext.n_cur < llamaContext.n_len {
while await !llamaContext.latest_llama_token_is_eog_or_reach_len{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to understand why this issue would only happen in Swift, and not in the other parallel projects, and looking at those, it looks like the paradigm is usually checking n_cur <= n_len -- but the Swift library (prior to your change) checks n_cur < n_len.

I'm not able to run this Swift library very easily -- could you please try changing this line:

Suggested change
while await !llamaContext.latest_llama_token_is_eog_or_reach_len{
while await llamaContext.n_cur <= llamaContext.n_len {

If you do that, do you still see the problem?

The main reason I ask is that -- as much as possible -- I'd like to keep the logic for all of the various libraries more or less parallel to each other, and I'm leery of one-off solutions that are only implemented in one platform at a time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"If you do that, do you still see the problem?" --- You changed back to the bug.
The problem is that, the original code for swiftUI frontend haven't checked the circumstances when there is EOT/EOS, so frontend keeps printing blank lines.
In the desired outcome, When there is EOT/EOS, frontend needs to stop generating any blank lines.
So, my solution is, besides checking if reaching n_len, you also need to check if the latest token is EOT/EOS, that is the whole meaning of why I added variable "latest_llama_token_is_eog_or_reach_len".

FYI, I haven't looked each line of original LLAMA code, but I did my best to use fewest line to fix this stupid bug.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because current n_len is just 64, you are not going to see the problem. If you change the n_len bigger, like 1024, you might find it keeps generating blank lines even after EOT/EOS

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You changed back to the bug.

The original was <, I changed it to <= -- is there any difference in behavior if you do that?

Copy link
Collaborator

@HanClinto HanClinto Jul 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, I haven't looked each line of original LLAMA code, but I did my best to use fewest line to fix this stupid bug.

And yeah, don't feel bad -- it's too large of a project for anyone to have done that. I hope I'm not frustrating you -- I haven't looked at each line either. That's largely why I'm asking this. I'm not familiar with the Swift code, and am mainly trying to keep the various libraries from diverging too much. If a different portion of the library already solved this same problem in a different way, then in so far as we can keep the solutions more parallel, we should do that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean about needing EOT/EOS detection in addition to correct n_len handling.

What I haven't yet tracked down is -- how do the other portions of the code deal with EOT/EOS detection, and how can we keep the Swift library more parallel with those?

We don't need to read every line, but understanding how at least one other section of the code does it would be good. And if no other place does it, then maybe your fix should be replicated elsewhere.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the check in simple example:

// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
LOG_TEE("\n");
break;
}

The proposed change is good, but just change the name of the variable to is_done

let result = await llamaContext.completion_loop()
await MainActor.run {
self.messageLog += "\(result)"
Expand Down