feat(#23): support llamacpp api keys
mcharytoniuk committed Oct 19, 2024
1 parent 43636b4 commit c654b03
Showing 5 changed files with 10 additions and 163 deletions.
98 changes: 4 additions & 94 deletions llamacpp/LlamaCppClient.go
@@ -1,12 +1,9 @@
 package llamacpp
 
 import (
-    "bufio"
-    "bytes"
     "context"
     "encoding/json"
     "errors"
-    "io"
     "net/http"
 )
 
@@ -23,97 +20,6 @@ type LlamaCppClient struct {
     LlamaCppConfiguration *LlamaCppConfiguration
 }
 
-func (self *LlamaCppClient) GenerateCompletion(
-    ctx context.Context,
-    responseChannel chan LlamaCppCompletionToken,
-    llamaCppCompletionRequest LlamaCppCompletionRequest,
-) {
-    defer close(responseChannel)
-
-    body, err := json.Marshal(llamaCppCompletionRequest)
-
-    if err != nil {
-        responseChannel <- LlamaCppCompletionToken{
-            Error: err,
-        }
-
-        return
-    }
-
-    request, err := http.NewRequestWithContext(
-        ctx,
-        "POST",
-        self.LlamaCppConfiguration.HttpAddress.BuildUrlWithPath("completion").String(),
-        bytes.NewBuffer(body),
-    )
-
-    if err != nil {
-        responseChannel <- LlamaCppCompletionToken{
-            Error: err,
-        }
-
-        return
-    }
-
-    response, err := self.HttpClient.Do(request)
-
-    if err != nil {
-        responseChannel <- LlamaCppCompletionToken{
-            Error: err,
-        }
-
-        return
-    }
-
-    defer response.Body.Close()
-
-    if http.StatusOK != response.StatusCode {
-        responseChannel <- LlamaCppCompletionToken{
-            Error: ErrorNon200Response,
-        }
-
-        return
-    }
-
-    reader := bufio.NewReader(response.Body)
-
-    for {
-        line, err := reader.ReadBytes('\n')
-
-        if err != nil && err != io.EOF {
-            responseChannel <- LlamaCppCompletionToken{
-                Error: err,
-            }
-
-            return
-        }
-
-        var llamaCppCompletionToken LlamaCppCompletionToken
-
-        trimmedLine := bytes.TrimPrefix(line, []byte(CompletionDataPrefix))
-
-        if len(trimmedLine) < 2 {
-            continue
-        }
-
-        err = json.Unmarshal(trimmedLine, &llamaCppCompletionToken)
-
-        if err != nil {
-            responseChannel <- LlamaCppCompletionToken{
-                Error: err,
-            }
-
-            return
-        }
-
-        responseChannel <- llamaCppCompletionToken
-
-        if llamaCppCompletionToken.IsLast {
-            return
-        }
-    }
-}
-
 func (self *LlamaCppClient) GetHealth(
     ctx context.Context,
     responseChannel chan<- LlamaCppHealthStatus,
@@ -199,6 +105,10 @@ func (self *LlamaCppClient) GetSlots(
         return
     }
 
+    if self.LlamaCppConfiguration.ApiKey != "" {
+        request.Header.Set("Authorization", "Bearer "+self.LlamaCppConfiguration.ApiKey)
+    }
+
     response, err := self.HttpClient.Do(request)
 
     if err != nil {
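The change itself is the standard bearer-token pattern: the Authorization header is attached only when an API key is configured, so unauthenticated llama.cpp instances keep working unchanged. Below is a minimal standalone sketch of the same pattern; the Config type, the newSlotsRequest helper, and the URL are illustrative stand-ins, not the project's actual API.

package main

import (
    "fmt"
    "net/http"
)

// Config stands in for LlamaCppConfiguration in this sketch.
type Config struct {
    ApiKey string
}

func newSlotsRequest(cfg *Config, baseUrl string) (*http.Request, error) {
    request, err := http.NewRequest("GET", baseUrl+"/slots", nil)
    if err != nil {
        return nil, err
    }

    // Attach credentials only when a key is configured, mirroring the diff above.
    if cfg.ApiKey != "" {
        request.Header.Set("Authorization", "Bearer "+cfg.ApiKey)
    }

    return request, nil
}

func main() {
    request, err := newSlotsRequest(&Config{ApiKey: "secret"}, "http://127.0.0.1:8081")
    if err != nil {
        panic(err)
    }

    fmt.Println(request.Header.Get("Authorization")) // Bearer secret
}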
61 changes: 0 additions & 61 deletions llamacpp/LlamaCppClient_test.go
@@ -35,33 +35,6 @@ func TestHealthIsObtained(t *testing.T) {
     assert.Nil(t, healthStatus.Error)
 }
 
-func TestCompletionsAreGenerated(t *testing.T) {
-    responseChannel := make(chan LlamaCppCompletionToken)
-
-    go llamaCppClient.GenerateCompletion(
-        context.Background(),
-        responseChannel,
-        LlamaCppCompletionRequest{
-            NPredict: 3,
-            Prompt:   "Who are you?",
-            Stream:   true,
-        },
-    )
-
-    var generatedTokens int = 0
-
-    for token := range responseChannel {
-        if token.Error != nil {
-            t.Fatal(token.Error)
-        } else {
-            generatedTokens += 1
-        }
-    }
-
-    // 3 tokens + 1 summary token
-    assert.Equal(t, 4, generatedTokens)
-}
-
 func TestSlotsAreObtained(t *testing.T) {
     // the test assumes llama.cpp instance running with 4 available slots
     // all of them idle
@@ -99,37 +72,3 @@ func TestSlotsAggregatedStatusIsbtained(t *testing.T) {
     assert.Equal(t, slotsAggregatedStatus.SlotsIdle, 4)
     assert.Equal(t, slotsAggregatedStatus.SlotsProcessing, 0)
 }
-
-func TestJsonSchemaConstrainedCompletionsAreGenerated(t *testing.T) {
-    responseChannel := make(chan LlamaCppCompletionToken)
-
-    go llamaCppClient.GenerateCompletion(
-        context.Background(),
-        responseChannel,
-        LlamaCppCompletionRequest{
-            JsonSchema: map[string]any{
-                "type": "object",
-                "properties": map[string]any{
-                    "hello": map[string]string{
-                        "type": "string",
-                    },
-                },
-            },
-            NPredict: 100,
-            Prompt:   "Say 'world' as a hello!",
-            Stream:   true,
-        },
-    )
-
-    acc := ""
-
-    for token := range responseChannel {
-        if token.Error != nil {
-            t.Fatal(token.Error)
-        } else {
-            acc += token.Content
-        }
-    }
-
-    assert.Equal(t, "{ \"hello\": \"world\" } ", acc)
-}
8 changes: 0 additions & 8 deletions llamacpp/LlamaCppCompletionToken.go

This file was deleted.

1 change: 1 addition & 0 deletions llamacpp/LlamaCppConfiguration.go
@@ -6,6 +6,7 @@ import (
 
 type LlamaCppConfiguration struct {
     HttpAddress *netcfg.HttpAddressConfiguration `json:"http_address"`
+    ApiKey      string
 }
 
 func (self *LlamaCppConfiguration) String() string {
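One detail worth noting: unlike HttpAddress, the new ApiKey field carries no json struct tag, so encoding/json falls back to the field name itself when serializing. A quick illustration; the values and the simplified HttpAddress type below are hypothetical, not the project's real netcfg type.

package main

import (
    "encoding/json"
    "fmt"
)

// Simplified stand-in for the real struct; HttpAddress is a plain string
// here instead of *netcfg.HttpAddressConfiguration.
type LlamaCppConfiguration struct {
    HttpAddress string `json:"http_address"`
    ApiKey      string
}

func main() {
    out, err := json.Marshal(LlamaCppConfiguration{
        HttpAddress: "http://127.0.0.1:8081",
        ApiKey:      "secret",
    })
    if err != nil {
        panic(err)
    }

    // Prints: {"http_address":"http://127.0.0.1:8081","ApiKey":"secret"}
    fmt.Println(string(out))
}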
5 changes: 5 additions & 0 deletions main.go
Expand Up @@ -130,6 +130,11 @@ func main() {
Value: "http",
Destination: &agent.ExternalLlamaCppConfiguration.HttpAddress.Scheme,
},
&cli.StringFlag{
Name: "local-llamacpp-api-key",
Value: "",
Destination: &agent.LocalLlamaCppConfiguration.ApiKey,
},
&cli.StringFlag{
Name: "local-llamacpp-host",
Value: "127.0.0.1",
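The flag is wired through the cli package's Destination field, which copies the parsed value straight into the agent's configuration struct before any action runs, so the API key needs no extra plumbing. A minimal sketch of the same mechanism, assuming the urfave/cli v2 API that the pointer-style &cli.StringFlag{...} literal suggests; the variable names and output are illustrative.

package main

import (
    "fmt"
    "log"
    "os"

    "github.com/urfave/cli/v2"
)

func main() {
    var apiKey string

    app := &cli.App{
        Flags: []cli.Flag{
            &cli.StringFlag{
                Name:        "local-llamacpp-api-key",
                Value:       "",
                Destination: &apiKey, // parsed value lands here before Action runs
            },
        },
        Action: func(ctx *cli.Context) error {
            fmt.Println("api key configured:", apiKey != "")
            return nil
        },
    }

    if err := app.Run(os.Args); err != nil {
        log.Fatal(err)
    }
}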
