Skip to content

Commit

Permalink
add perplexity prompter
Browse files Browse the repository at this point in the history
  • Loading branch information
Southclaws committed Dec 29, 2024
1 parent 5d5faf0 commit 31d5838
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 0 deletions.
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ require (
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/r3labs/sse/v2 v2.10.0 // indirect
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
Expand All @@ -149,6 +150,7 @@ require (
github.com/tidwall/sjson v1.2.5 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/tmaxmax/go-sse v0.10.0 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/vanng822/css v1.0.1 // indirect
github.com/vanng822/go-premailer v1.21.0 // indirect
Expand All @@ -165,6 +167,7 @@ require (
golang.org/x/tools v0.25.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241113202542-65e8d215514f // indirect
google.golang.org/grpc v1.67.1 // indirect
gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
modernc.org/libc v1.61.0 // indirect
modernc.org/mathutil v1.6.0 // indirect
Expand Down
7 changes: 7 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,8 @@ github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:Om
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/r3labs/sse/v2 v2.10.0 h1:hFEkLLFY4LDifoHdiCN/LlGBAdVJYsANaLqNYa1l/v0=
github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I=
github.com/rabbitmq/amqp091-go v1.10.0 h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw=
github.com/rabbitmq/amqp091-go v1.10.0/go.mod h1:Hy4jKW5kQART1u+JkDTF9YYOQUHXqMuhrgxOEeS7G4o=
github.com/redis/rueidis v1.0.49 h1:uhjMcQ663R8st3saoo85VV9Ce37zfvRXiveZcBrS3YQ=
Expand Down Expand Up @@ -494,6 +496,8 @@ github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFA
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/tmaxmax/go-sse v0.10.0 h1:j9F93WB4Hxt8wUf6oGffMm4dutALvUPoDDxfuDQOSqA=
github.com/tmaxmax/go-sse v0.10.0/go.mod h1:u/2kZQR1tyngo1lKaNCj1mJmhXGZWS1Zs5yiSOD+Eg8=
github.com/twilio/twilio-go v1.23.5 h1:5ksHynnYhjKf1vG7KK7+jujEj/DhQ1knwQAhNuDExW4=
github.com/twilio/twilio-go v1.23.5/go.mod h1:zRkMjudW7v7MqQ3cWNZmSoZJ7EBjPZ4OpNh2zm7Q6ko=
github.com/twilio/twilio-go v1.23.6 h1:9gjIZ8w3MN+8ifPZgK74vF3CLfnJ6ytMNqOI2r2ipLs=
Expand Down Expand Up @@ -596,6 +600,7 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
Expand Down Expand Up @@ -743,6 +748,8 @@ google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFyt
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io=
google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y=
gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type Config struct {

LanguageModelProvider string `envconfig:"LANGUAGE_MODEL_PROVIDER"`
OpenAIKey string `envconfig:"OPENAI_API_KEY"`
PerplexityAPIKey string `envconfig:"PERPLEXITY_API_KEY"`

SemdexProvider string `envconfig:"SEMDEX_PROVIDER" default:""`
WeaviateURL string `envconfig:"WEAVIATE_URL"`
Expand Down
3 changes: 3 additions & 0 deletions internal/infrastructure/ai/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ func New(cfg config.Config) (Prompter, error) {
case "openai":
return newOpenAI(cfg)

case "perplexity":
return newPerplexity(cfg)

case "mock":
return newMock()

Expand Down
208 changes: 208 additions & 0 deletions internal/infrastructure/ai/perplexity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package ai

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"

"github.com/Southclaws/storyden/internal/config"
"github.com/openai/openai-go/packages/ssestream"
)

const (
DefaultEndpoint = "https://api.perplexity.ai/chat/completions"
DefautTimeout = 10 * time.Second
)

const (
Llama_3_1SonarSmall_128kChat = "llama-3.1-sonar-small-128k-chat"
Llama_3_1SonarLarge_128kChat = "llama-3.1-sonar-large-128k-chat"
Llama_3_1SonarSmall_128kOnline = "llama-3.1-sonar-small-128k-online"
Llama_3_1SonarLarge_128kOnline = "llama-3.1-sonar-large-128k-online"
Llama_3_1_8bInstruct = "llama-3.1-8b-instruct"
Llama_3_1_70bInstruct = "llama-3.1-70b-instruct"
)

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type CompletionRequest struct {
Messages []Message `json:"messages"`
Model string `json:"model"`
Stream bool `json:"stream"`
}

type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

type Choice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
Message Message `json:"message"`
Delta Message `json:"delta"`
}

type CompletionResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Created int `json:"created"`
Usage Usage `json:"usage"`
Citations []string `json:"citations"`
Object string `json:"object"`
Choices []Choice `json:"choices"`
}

type Perplexity struct {
endpoint string
apiKey string
model string
httpClient *http.Client
httpTimeout time.Duration
}

func newPerplexity(cfg config.Config) (*Perplexity, error) {
s := &Perplexity{
apiKey: cfg.PerplexityAPIKey,
endpoint: DefaultEndpoint,
model: Llama_3_1SonarSmall_128kChat,
httpClient: &http.Client{},
httpTimeout: DefautTimeout,
}
return s, nil
}

func (s *Perplexity) Prompt(ctx context.Context, input string) (*Result, error) {
r := &CompletionResponse{}

reqBody := CompletionRequest{
Messages: []Message{{Role: "user", Content: input}},
Model: s.model,
}

requestBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}

ctx, cancel := context.WithDeadline(ctx, time.Now().Add(s.httpTimeout))
defer cancel()

req, err := http.NewRequestWithContext(ctx, "POST", s.endpoint, bytes.NewBuffer(requestBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("Authorization", "Bearer "+s.apiKey)
req.Header.Set("Content-Type", "application/json")

resp, err := s.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusUnauthorized {
return nil, fmt.Errorf("unauthorized: check your API key")
}
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}

err = json.Unmarshal(body, r)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal response body: %w - body response=%s", err, string(body))
}

return &Result{
Answer: r.Choices[0].Message.Content,
}, nil
}

func (s *Perplexity) PromptStream(ctx context.Context, input string) (chan string, chan error) {
outch := make(chan string)
errch := make(chan error)

resp, err := func() (*http.Response, error) {
reqBody := CompletionRequest{
Stream: true,
Messages: []Message{{Role: "user", Content: input}},
Model: s.model,
}

requestBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}

req, err := http.NewRequestWithContext(ctx, "POST", s.endpoint, bytes.NewBuffer(requestBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("Authorization", "Bearer "+s.apiKey)
req.Header.Set("Content-Type", "application/json")

resp, err := s.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}

return resp, nil
}()
if err != nil {
errch <- err
return outch, errch
}

dec := ssestream.NewDecoder(resp)

go func() {
defer resp.Body.Close()
defer close(outch)
defer close(errch)

for dec.Next() {
event := dec.Event()
var cr CompletionResponse

if err := json.Unmarshal(event.Data, &cr); err != nil {
errch <- fmt.Errorf("failed to unmarshal SSE event: %w", err)
return
}

outch <- cr.Choices[0].Delta.Content

if cr.Choices[0].FinishReason == "stop" {
break
}
}

if dec.Err() != nil {
errch <- fmt.Errorf("failed to read SSE stream: %w", dec.Err())
}
}()

return outch, errch
}

func replaceCitations(message string, citations []string) string {
return message
}

func (o *Perplexity) EmbeddingFunc() func(ctx context.Context, text string) ([]float32, error) {
panic("not implemented")
}

0 comments on commit 31d5838

Please sign in to comment.