From 090b7f8a71b588747ec8d7f95f335f9a05b5f513 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 22 May 2024 00:58:53 +0200 Subject: [PATCH 01/15] tools --- examples/llm/openai/thread/main.go | 24 ++++++++++++------------ go.mod | 1 + go.sum | 2 ++ llm/openai/function.go | 19 +++++++++++++++++++ 4 files changed, 34 insertions(+), 12 deletions(-) diff --git a/examples/llm/openai/thread/main.go b/examples/llm/openai/thread/main.go index 5180349c..695f4714 100644 --- a/examples/llm/openai/thread/main.go +++ b/examples/llm/openai/thread/main.go @@ -2,11 +2,12 @@ package main import ( "context" + "encoding/json" "fmt" - "strings" "github.com/henomis/lingoose/llm/openai" "github.com/henomis/lingoose/thread" + "github.com/henomis/lingoose/tools/dalle" "github.com/henomis/lingoose/transformer" ) @@ -32,15 +33,7 @@ func newStr(str string) *string { func main() { openaillm := openai.New().WithModel(openai.GPT4o) - openaillm.WithToolChoice(newStr("auto")) - err := openaillm.BindFunction( - crateImage, - "createImage", - "use this function to create an image from a description", - ) - if err != nil { - panic(err) - } + openaillm.WithToolChoice(newStr("auto")).WithTools(dalle.New()) t := thread.New().AddMessage( thread.NewUserMessage().AddContent( @@ -48,15 +41,22 @@ func main() { ), ) - err = openaillm.Generate(context.Background(), t) + err := openaillm.Generate(context.Background(), t) if err != nil { panic(err) } if t.LastMessage().Role == thread.RoleTool { + var output dalle.Output + + err = json.Unmarshal([]byte(t.LastMessage().Contents[0].AsToolResponseData().Result), &output) + if err != nil { + panic(err) + } + t.AddMessage(thread.NewUserMessage().AddContent( thread.NewImageContentFromURL( - strings.ReplaceAll(t.LastMessage().Contents[0].AsToolResponseData().Result, `"`, ""), + output.ImageURL, ), ).AddContent( thread.NewTextContent("can you describe the image?"), diff --git a/go.mod b/go.mod index 25c30a82..4131f84e 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/henomis/restclientgo v1.2.0 github.com/invopop/jsonschema v0.7.0 github.com/sashabaranov/go-openai v1.24.0 + golang.org/x/net v0.25.0 ) require ( diff --git a/go.sum b/go.sum index e2d3391c..a79137ce 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/llm/openai/function.go b/llm/openai/function.go index 4b9c77bb..9c5953bd 100644 --- a/llm/openai/function.go +++ b/llm/openai/function.go @@ -79,6 +79,25 @@ func (o *OpenAI) BindFunction( return nil } +type Tool interface { + Description() string + Name() string + Fn() any +} + +func (o OpenAI) WithTools(tools ...Tool) OpenAI { + for _, tool := range tools { + function, err := bindFunction(tool.Fn(), tool.Name(), tool.Description()) + if err != nil { + fmt.Println(err) + } + + o.functions[tool.Name()] = *function + } + + return o +} + func (o *Legacy) getFunctions() []openai.FunctionDefinition { var functions []openai.FunctionDefinition From 4dad3cad05854b3b4e2157e18a2cfa28f9fa2716 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 22 May 2024 01:01:14 +0200 Subject: [PATCH 02/15] chore: add tools --- examples/llm/openai/tools/python/main.go | 32 +++++ examples/tools/duckduckgo/main.go | 15 ++ examples/tools/python/main.go | 16 +++ tools/dalle/dalle.go | 50 +++++++ tools/duckduckgo/api.go | 168 +++++++++++++++++++++++ tools/duckduckgo/duckduckgo.go | 77 +++++++++++ tools/python/python.go | 74 ++++++++++ 7 files changed, 432 insertions(+) create mode 100644 examples/llm/openai/tools/python/main.go create mode 100644 examples/tools/duckduckgo/main.go create mode 100644 examples/tools/python/main.go create mode 100644 tools/dalle/dalle.go create mode 100644 tools/duckduckgo/api.go create mode 100644 tools/duckduckgo/duckduckgo.go create mode 100644 tools/python/python.go diff --git a/examples/llm/openai/tools/python/main.go b/examples/llm/openai/tools/python/main.go new file mode 100644 index 00000000..2d4bd7c0 --- /dev/null +++ b/examples/llm/openai/tools/python/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "fmt" + + "github.com/henomis/lingoose/llm/openai" + "github.com/henomis/lingoose/thread" + "github.com/henomis/lingoose/tools/python" +) + +func main() { + newStr := func(str string) *string { + return &str + } + llm := openai.New().WithModel(openai.GPT3Dot5Turbo0613).WithToolChoice(newStr("auto")).WithTools( + python.New(), + ) + + t := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent("calculate reverse string of 'ailatiditalia', don't try to guess, let's use appropriate tools"), + ), + ) + + llm.Generate(context.Background(), t) + if t.LastMessage().Role == thread.RoleTool { + llm.Generate(context.Background(), t) + } + + fmt.Println(t) +} diff --git a/examples/tools/duckduckgo/main.go b/examples/tools/duckduckgo/main.go new file mode 100644 index 00000000..7be5f65c --- /dev/null +++ b/examples/tools/duckduckgo/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "fmt" + + "github.com/henomis/lingoose/tools/duckduckgo" +) + +func main() { + + t := duckduckgo.New().WithMaxResults(5) + f := t.Fn().(duckduckgo.FnPrototype) + + fmt.Println(f(duckduckgo.Input{Query: "Simone Vellei"})) +} diff --git a/examples/tools/python/main.go b/examples/tools/python/main.go new file mode 100644 index 00000000..57052fb6 --- /dev/null +++ b/examples/tools/python/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + + "github.com/henomis/lingoose/tools/python" +) + +func main() { + t := python.New().WithPythonPath("python3") + + pythonScript := `print("Hello from Python!")` + f := t.Fn().(python.FnPrototype) + + fmt.Println(f(python.Input{PythonCode: pythonScript})) +} diff --git a/tools/dalle/dalle.go b/tools/dalle/dalle.go new file mode 100644 index 00000000..0508d418 --- /dev/null +++ b/tools/dalle/dalle.go @@ -0,0 +1,50 @@ +package dalle + +import ( + "context" + "fmt" + + "github.com/henomis/lingoose/transformer" +) + +type Tool struct { +} + +type Input struct { + Description string `json:"description" jsonschema:"description=the description of the image that should be created"` +} + +type Output struct { + Error string `json:"error,omitempty"` + ImageURL string `json:"imageURL,omitempty"` +} + +type FnPrototype func(Input) Output + +func New() *Tool { + return &Tool{} +} + +func (t *Tool) Name() string { + return "dalle" +} + +func (t *Tool) Description() string { + return "A tool that creates an image from a description." +} + +func (t *Tool) Fn() any { + return t.fn +} + +func (t *Tool) fn(i Input) Output { + d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize512x512) + imageURL, err := d.Transform(context.Background(), i.Description) + if err != nil { + return Output{Error: fmt.Sprintf("error creating image: %v", err)} + } + + fmt.Println("Image created with url:", imageURL) + + return Output{ImageURL: imageURL.(string)} +} diff --git a/tools/duckduckgo/api.go b/tools/duckduckgo/api.go new file mode 100644 index 00000000..a301c5e7 --- /dev/null +++ b/tools/duckduckgo/api.go @@ -0,0 +1,168 @@ +package duckduckgo + +import ( + "bytes" + "io" + "regexp" + "strings" + + "github.com/henomis/restclientgo" + "golang.org/x/net/html" +) + +const ( + class = "class" +) + +type request struct { + Query string +} + +type response struct { + MaxResults uint + HTTPStatusCode int + RawBody []byte + Results []result +} + +type result struct { + Title string + Info string + URL string +} + +func (r *request) Path() (string, error) { + return "/html/?q=" + r.Query, nil +} + +func (r *request) Encode() (io.Reader, error) { + return nil, nil +} + +func (r *request) ContentType() string { + return "" +} + +func (r *response) Decode(body io.Reader) error { + results, err := r.parseBody(body) + if err != nil { + return err + } + + r.Results = results + return nil +} + +func (r *response) SetBody(body io.Reader) error { + r.RawBody, _ = io.ReadAll(body) + return nil +} + +func (r *response) AcceptContentType() string { + return "text/html" +} + +func (r *response) SetStatusCode(code int) error { + r.HTTPStatusCode = code + return nil +} + +func (r *response) SetHeaders(_ restclientgo.Headers) error { return nil } + +func (r *response) parseBody(body io.Reader) ([]result, error) { + doc, err := html.Parse(body) + if err != nil { + return nil, err + } + ch := make(chan result) + go r.findWebResults(ch, doc) + + results := []result{} + for n := range ch { + results = append(results, n) + } + + return results, nil +} + +func (r *response) findWebResults(ch chan result, doc *html.Node) { + var results uint + var f func(*html.Node) + f = func(n *html.Node) { + if results >= r.MaxResults { + return + } + if n.Type == html.ElementNode && n.Data == "div" { + for _, div := range n.Attr { + if div.Key == class && strings.Contains(div.Val, "web-result") { + info, href := r.findInfo(n) + ch <- result{ + Title: r.findTitle(n), + Info: info, + URL: href, + } + results++ + break + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + f(c) + } + } + f(doc) + close(ch) +} + +func (r *response) findTitle(n *html.Node) string { + var title string + var f func(*html.Node) + f = func(n *html.Node) { + if n.Type == html.ElementNode && n.Data == "a" { + for _, a := range n.Attr { + if a.Key == class && strings.Contains(a.Val, "result__a") { + title = n.FirstChild.Data + break + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + f(c) + } + } + f(n) + return title +} + +//nolint:gocognit +func (r *response) findInfo(n *html.Node) (string, string) { + var info string + var link string + var f func(*html.Node) + f = func(n *html.Node) { + if n.Type == html.ElementNode && n.Data == "a" { + for _, a := range n.Attr { + if a.Key == class && strings.Contains(a.Val, "result__snippet") { + var b bytes.Buffer + _ = html.Render(&b, n) + + re := regexp.MustCompile("<.*?>") + info = html.UnescapeString(re.ReplaceAllString(b.String(), "")) + + for _, h := range n.Attr { + if h.Key == "href" { + link = "https:" + h.Val + break + } + } + break + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + f(c) + } + } + f(n) + return info, link +} diff --git a/tools/duckduckgo/duckduckgo.go b/tools/duckduckgo/duckduckgo.go new file mode 100644 index 00000000..3c5f2b7d --- /dev/null +++ b/tools/duckduckgo/duckduckgo.go @@ -0,0 +1,77 @@ +package duckduckgo + +import ( + "context" + "fmt" + "net/http" + + "github.com/henomis/restclientgo" +) + +type Tool struct { + maxResults uint + userAgent string + restClient *restclientgo.RestClient +} + +type Input struct { + Query string `json:"python_code" jsonschema:"description=the query to search for"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Results []result `json:"results,omitempty"` +} + +type FnPrototype func(Input) Output + +func New() *Tool { + t := &Tool{ + maxResults: 1, + } + + restClient := restclientgo.New("https://html.duckduckgo.com"). + WithRequestModifier( + func(r *http.Request) *http.Request { + r.Header.Add("User-Agent", t.userAgent) + return r + }, + ) + + t.restClient = restClient + return t +} + +func (t *Tool) WithUserAgent(userAgent string) *Tool { + t.userAgent = userAgent + return t +} + +func (t *Tool) WithMaxResults(maxResults uint) *Tool { + t.maxResults = maxResults + return t +} + +func (t *Tool) Name() string { + return "duckduckgo" +} + +func (t *Tool) Description() string { + return "A tool that searches DuckDuckGo for a query." +} + +func (t *Tool) Fn() any { + return t.fn +} + +func (t *Tool) fn(i Input) Output { + req := &request{Query: i.Query} + res := &response{MaxResults: t.maxResults} + + err := t.restClient.Get(context.Background(), req, res) + if err != nil { + return Output{Error: fmt.Sprintf("failed to search DuckDuckGo: %v", err)} + } + + return Output{Results: res.Results} +} diff --git a/tools/python/python.go b/tools/python/python.go new file mode 100644 index 00000000..b7ccc2c7 --- /dev/null +++ b/tools/python/python.go @@ -0,0 +1,74 @@ +package python + +import ( + "bytes" + "fmt" + "os" + "os/exec" +) + +type Tool struct { + pythonPath string +} + +func New() *Tool { + return &Tool{ + pythonPath: "python3", + } +} + +func (t *Tool) WithPythonPath(pythonPath string) *Tool { + t.pythonPath = pythonPath + return t +} + +type Input struct { + PythonCode string `json:"python_code" jsonschema:"description=python code that prints the final result to stdout."` +} + +type Output struct { + Error string `json:"error,omitempty"` + Result string `json:"result,omitempty"` +} + +type FnPrototype func(Input) Output + +func (t *Tool) Name() string { + return "python" +} + +func (t *Tool) Description() string { + return "A tool that runs Python code using the Python interpreter. The code should print the final result to stdout." +} + +func (t *Tool) Fn() any { + return t.fn +} + +//nolint:gosec +func (t *Tool) fn(i Input) Output { + // Create a command to run the Python interpreter with the script. + cmd := exec.Command(t.pythonPath, "-c", i.PythonCode) + + os.WriteFile("/tmp/pippo.py", []byte(i.PythonCode), 0644) + + // Create a buffer to capture the output. + var out bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &stderr + + // Run the command. + err := cmd.Run() + if err != nil { + return Output{ + Error: fmt.Sprintf("failed to run script: %v, stderr: %v", err, stderr.String()), + } + } + + s := out.String() + _ = s + + // Return the output as a string. + return Output{Result: out.String()} +} From 6dfc0e0779efe6eb57de8067d611fd219251ffa1 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 22 May 2024 01:01:48 +0200 Subject: [PATCH 03/15] fix --- tools/python/python.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/tools/python/python.go b/tools/python/python.go index b7ccc2c7..af3d91a0 100644 --- a/tools/python/python.go +++ b/tools/python/python.go @@ -3,7 +3,6 @@ package python import ( "bytes" "fmt" - "os" "os/exec" ) @@ -50,8 +49,6 @@ func (t *Tool) fn(i Input) Output { // Create a command to run the Python interpreter with the script. cmd := exec.Command(t.pythonPath, "-c", i.PythonCode) - os.WriteFile("/tmp/pippo.py", []byte(i.PythonCode), 0644) - // Create a buffer to capture the output. var out bytes.Buffer var stderr bytes.Buffer From 55a2c577b9f97f397e1de050eb377a12b66a00ae Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 22 May 2024 01:02:08 +0200 Subject: [PATCH 04/15] fix --- tools/python/python.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/tools/python/python.go b/tools/python/python.go index af3d91a0..8a2fe0d8 100644 --- a/tools/python/python.go +++ b/tools/python/python.go @@ -63,9 +63,6 @@ func (t *Tool) fn(i Input) Output { } } - s := out.String() - _ = s - // Return the output as a string. return Output{Result: out.String()} } From c8cf9b3c535980b565f64d502d58adeb0f3b0918 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 22 May 2024 11:11:00 +0200 Subject: [PATCH 05/15] fix: remove println --- tools/dalle/dalle.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/tools/dalle/dalle.go b/tools/dalle/dalle.go index 0508d418..ebd75882 100644 --- a/tools/dalle/dalle.go +++ b/tools/dalle/dalle.go @@ -44,7 +44,5 @@ func (t *Tool) fn(i Input) Output { return Output{Error: fmt.Sprintf("error creating image: %v", err)} } - fmt.Println("Image created with url:", imageURL) - return Output{ImageURL: imageURL.(string)} } From 66a58f5481ab4d9e0e3a114c22cfb4116ec4fcba Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 22 May 2024 15:38:54 +0200 Subject: [PATCH 06/15] chore: add tools --- tools/duckduckgo/duckduckgo.go | 4 +- tools/llm/llm.go | 60 +++++++++++++++++++++++++ tools/rag/rag.go | 54 +++++++++++++++++++++++ tools/tool_router/tool_router.go | 75 ++++++++++++++++++++++++++++++++ 4 files changed, 191 insertions(+), 2 deletions(-) create mode 100644 tools/llm/llm.go create mode 100644 tools/rag/rag.go create mode 100644 tools/tool_router/tool_router.go diff --git a/tools/duckduckgo/duckduckgo.go b/tools/duckduckgo/duckduckgo.go index 3c5f2b7d..80b2da55 100644 --- a/tools/duckduckgo/duckduckgo.go +++ b/tools/duckduckgo/duckduckgo.go @@ -15,7 +15,7 @@ type Tool struct { } type Input struct { - Query string `json:"python_code" jsonschema:"description=the query to search for"` + Query string `json:"query" jsonschema:"description=the query to search for"` } type Output struct { @@ -57,7 +57,7 @@ func (t *Tool) Name() string { } func (t *Tool) Description() string { - return "A tool that searches DuckDuckGo for a query." + return "A tool that searches on duckduckgo internet search engine for a query." } func (t *Tool) Fn() any { diff --git a/tools/llm/llm.go b/tools/llm/llm.go new file mode 100644 index 00000000..8b05bb55 --- /dev/null +++ b/tools/llm/llm.go @@ -0,0 +1,60 @@ +package llm + +import ( + "context" + + "github.com/henomis/lingoose/thread" +) + +type LLM interface { + Generate(context.Context, *thread.Thread) error +} + +type Tool struct { + llm LLM +} + +func New(llm LLM) *Tool { + return &Tool{ + llm: llm, + } +} + +type Input struct { + Query string `json:"query" jsonschema:"description=user query"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Result string `json:"result,omitempty"` +} + +type FnPrototype func(Input) Output + +func (t *Tool) Name() string { + return "llm" +} + +func (t *Tool) Description() string { + return "A tool that uses a language model to generate a response to a user query." +} + +func (t *Tool) Fn() any { + return t.fn +} + +//nolint:gosec +func (t *Tool) fn(i Input) Output { + th := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent(i.Query), + ), + ) + + err := t.llm.Generate(context.Background(), th) + if err != nil { + return Output{Error: err.Error()} + } + + return Output{Result: th.LastMessage().Contents[0].AsString()} +} diff --git a/tools/rag/rag.go b/tools/rag/rag.go new file mode 100644 index 00000000..13f05682 --- /dev/null +++ b/tools/rag/rag.go @@ -0,0 +1,54 @@ +package rag + +import ( + "context" + "strings" + + "github.com/henomis/lingoose/rag" +) + +type Tool struct { + rag *rag.RAG + topic string +} + +func New(rag *rag.RAG, topic string) *Tool { + return &Tool{ + rag: rag, + topic: topic, + } +} + +type Input struct { + Query string `json:"rag_query" jsonschema:"description=search query"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Result string `json:"result,omitempty"` +} + +type FnPrototype func(Input) Output + +func (t *Tool) Name() string { + return "rag" +} + +func (t *Tool) Description() string { + return "A tool that searches information ONLY for this topic: " + t.topic + ". DO NOT use this tool for other topics." +} + +func (t *Tool) Fn() any { + return t.fn +} + +//nolint:gosec +func (t *Tool) fn(i Input) Output { + results, err := t.rag.Retrieve(context.Background(), i.Query) + if err != nil { + return Output{Error: err.Error()} + } + + // Return the output as a string. + return Output{Result: strings.Join(results, "\n")} +} diff --git a/tools/tool_router/tool_router.go b/tools/tool_router/tool_router.go new file mode 100644 index 00000000..529dfab2 --- /dev/null +++ b/tools/tool_router/tool_router.go @@ -0,0 +1,75 @@ +package toolrouter + +import ( + "context" + + "github.com/henomis/lingoose/thread" +) + +type TTool interface { + Description() string + Name() string + Fn() any +} + +type Tool struct { + llm LLM + tools []TTool +} + +type LLM interface { + Generate(context.Context, *thread.Thread) error +} + +func New(llm LLM, tools ...TTool) *Tool { + return &Tool{ + tools: tools, + llm: llm, + } +} + +type Input struct { + Query string `json:"query" jsonschema:"description=user query"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Result any `json:"result,omitempty"` +} + +type FnPrototype func(Input) Output + +func (t *Tool) Name() string { + return "query_router" +} + +func (t *Tool) Description() string { + return "A tool that select the right tool to answer to an user query." +} + +func (t *Tool) Fn() any { + return t.fn +} + +//nolint:gosec +func (t *Tool) fn(i Input) Output { + query := "Here's a list of available tools:\n\n" + for _, tool := range t.tools { + query += "Name: " + tool.Name() + "\nDescription: " + tool.Description() + "\n\n" + } + + query += "\nPlease select the right tool that can better answer the query '" + i.Query + "'. Give me only the name of the tool, nothing else." + + th := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent(query), + ), + ) + + err := t.llm.Generate(context.Background(), th) + if err != nil { + return Output{Error: err.Error()} + } + + return Output{Result: th.LastMessage().Contents[0].AsString()} +} From 863342d6786232bc6690dd0c01d67d8dde4920cb Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 22 May 2024 15:40:18 +0200 Subject: [PATCH 07/15] chore: fix example --- examples/llm/openai/tools/rag/main.go | 63 +++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 examples/llm/openai/tools/rag/main.go diff --git a/examples/llm/openai/tools/rag/main.go b/examples/llm/openai/tools/rag/main.go new file mode 100644 index 00000000..9bc147da --- /dev/null +++ b/examples/llm/openai/tools/rag/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "context" + "fmt" + "os" + + openaiembedder "github.com/henomis/lingoose/embedder/openai" + "github.com/henomis/lingoose/index" + "github.com/henomis/lingoose/index/vectordb/jsondb" + "github.com/henomis/lingoose/llm/openai" + "github.com/henomis/lingoose/rag" + "github.com/henomis/lingoose/thread" + "github.com/henomis/lingoose/tools/duckduckgo" + ragtool "github.com/henomis/lingoose/tools/rag" +) + +func main() { + + rag := rag.New( + index.New( + jsondb.New().WithPersist("index.json"), + openaiembedder.New(openaiembedder.AdaEmbeddingV2), + ), + ).WithChunkSize(1000).WithChunkOverlap(0) + + _, err := os.Stat("index.json") + if os.IsNotExist(err) { + err = rag.AddSources(context.Background(), "state_of_the_union.txt") + if err != nil { + panic(err) + } + } + + newStr := func(str string) *string { + return &str + } + llm := openai.New().WithModel(openai.GPT4o).WithToolChoice(newStr("auto")).WithTools( + ragtool.New(rag, "US covid vaccines"), + duckduckgo.New().WithMaxResults(5), + ) + + topics := []string{ + "how many covid vaccine doses US has donated to other countries", + "apple stock price", + } + + for _, topic := range topics { + t := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent("I would like to know something about " + topic + "."), + ), + ) + + llm.Generate(context.Background(), t) + if t.LastMessage().Role == thread.RoleTool { + llm.Generate(context.Background(), t) + } + + fmt.Println(t) + } + +} From a8eb5e54b0591739c308fcf8254ae8e50c65219f Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 23 May 2024 13:01:55 +0200 Subject: [PATCH 08/15] fix line lenght --- tools/tool_router/tool_router.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/tool_router/tool_router.go b/tools/tool_router/tool_router.go index 529dfab2..b06e763a 100644 --- a/tools/tool_router/tool_router.go +++ b/tools/tool_router/tool_router.go @@ -58,7 +58,8 @@ func (t *Tool) fn(i Input) Output { query += "Name: " + tool.Name() + "\nDescription: " + tool.Description() + "\n\n" } - query += "\nPlease select the right tool that can better answer the query '" + i.Query + "'. Give me only the name of the tool, nothing else." + query += "\nPlease select the right tool that can better answer the query '" + i.Query + + "'. Give me only the name of the tool, nothing else." th := thread.New().AddMessage( thread.NewUserMessage().AddContent( From a8459c0162fb117fe02357b09a1377c288ced063 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 23 May 2024 14:15:44 +0200 Subject: [PATCH 09/15] fix grammars --- examples/llm/openai/tools/python/main.go | 2 +- examples/llm/openai/tools/rag/main.go | 4 ++-- tools/duckduckgo/duckduckgo.go | 2 +- tools/tool_router/tool_router.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/llm/openai/tools/python/main.go b/examples/llm/openai/tools/python/main.go index 2d4bd7c0..1e133410 100644 --- a/examples/llm/openai/tools/python/main.go +++ b/examples/llm/openai/tools/python/main.go @@ -19,7 +19,7 @@ func main() { t := thread.New().AddMessage( thread.NewUserMessage().AddContent( - thread.NewTextContent("calculate reverse string of 'ailatiditalia', don't try to guess, let's use appropriate tools"), + thread.NewTextContent("calculate reverse string of 'ailatiditalia', don't try to guess, let's use appropriate tool"), ), ) diff --git a/examples/llm/openai/tools/rag/main.go b/examples/llm/openai/tools/rag/main.go index 9bc147da..0872d388 100644 --- a/examples/llm/openai/tools/rag/main.go +++ b/examples/llm/openai/tools/rag/main.go @@ -42,13 +42,13 @@ func main() { topics := []string{ "how many covid vaccine doses US has donated to other countries", - "apple stock price", + "who's the lingoose github project author", } for _, topic := range topics { t := thread.New().AddMessage( thread.NewUserMessage().AddContent( - thread.NewTextContent("I would like to know something about " + topic + "."), + thread.NewTextContent("Please tell me " + topic + "."), ), ) diff --git a/tools/duckduckgo/duckduckgo.go b/tools/duckduckgo/duckduckgo.go index 80b2da55..9601e3c0 100644 --- a/tools/duckduckgo/duckduckgo.go +++ b/tools/duckduckgo/duckduckgo.go @@ -57,7 +57,7 @@ func (t *Tool) Name() string { } func (t *Tool) Description() string { - return "A tool that searches on duckduckgo internet search engine for a query." + return "A tool that uses the DuckDuckGo internet search engine for a query." } func (t *Tool) Fn() any { diff --git a/tools/tool_router/tool_router.go b/tools/tool_router/tool_router.go index b06e763a..ec12a231 100644 --- a/tools/tool_router/tool_router.go +++ b/tools/tool_router/tool_router.go @@ -44,7 +44,7 @@ func (t *Tool) Name() string { } func (t *Tool) Description() string { - return "A tool that select the right tool to answer to an user query." + return "A tool that select the right tool to answer to user queries." } func (t *Tool) Fn() any { From 2fc04f870cd72ab10bb7974602ac19e23ea847ab Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 24 May 2024 11:01:33 +0200 Subject: [PATCH 10/15] chore: add serpapi --- examples/llm/openai/tools/rag/main.go | 6 +- examples/tools/serpapi/main.go | 15 +++ tools/python/python.go | 2 +- tools/rag/rag.go | 2 +- tools/serpapi/api.go | 186 ++++++++++++++++++++++++++ tools/serpapi/serpapi.go | 90 +++++++++++++ 6 files changed, 296 insertions(+), 5 deletions(-) create mode 100644 examples/tools/serpapi/main.go create mode 100644 tools/serpapi/api.go create mode 100644 tools/serpapi/serpapi.go diff --git a/examples/llm/openai/tools/rag/main.go b/examples/llm/openai/tools/rag/main.go index 0872d388..6cf90fec 100644 --- a/examples/llm/openai/tools/rag/main.go +++ b/examples/llm/openai/tools/rag/main.go @@ -11,8 +11,8 @@ import ( "github.com/henomis/lingoose/llm/openai" "github.com/henomis/lingoose/rag" "github.com/henomis/lingoose/thread" - "github.com/henomis/lingoose/tools/duckduckgo" ragtool "github.com/henomis/lingoose/tools/rag" + "github.com/henomis/lingoose/tools/serpapi" ) func main() { @@ -37,12 +37,12 @@ func main() { } llm := openai.New().WithModel(openai.GPT4o).WithToolChoice(newStr("auto")).WithTools( ragtool.New(rag, "US covid vaccines"), - duckduckgo.New().WithMaxResults(5), + serpapi.New(), ) topics := []string{ "how many covid vaccine doses US has donated to other countries", - "who's the lingoose github project author", + "who's the author of LinGoose github project", } for _, topic := range topics { diff --git a/examples/tools/serpapi/main.go b/examples/tools/serpapi/main.go new file mode 100644 index 00000000..d0578d3a --- /dev/null +++ b/examples/tools/serpapi/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "fmt" + + "github.com/henomis/lingoose/tools/serpapi" +) + +func main() { + + t := serpapi.New() + f := t.Fn().(serpapi.FnPrototype) + + fmt.Println(f(serpapi.Input{Query: "Simone Vellei"})) +} diff --git a/tools/python/python.go b/tools/python/python.go index 8a2fe0d8..2ac686f2 100644 --- a/tools/python/python.go +++ b/tools/python/python.go @@ -30,7 +30,7 @@ type Output struct { Result string `json:"result,omitempty"` } -type FnPrototype func(Input) Output +type FnPrototype = func(Input) Output func (t *Tool) Name() string { return "python" diff --git a/tools/rag/rag.go b/tools/rag/rag.go index 13f05682..54a5554f 100644 --- a/tools/rag/rag.go +++ b/tools/rag/rag.go @@ -28,7 +28,7 @@ type Output struct { Result string `json:"result,omitempty"` } -type FnPrototype func(Input) Output +type FnPrototype = func(Input) Output func (t *Tool) Name() string { return "rag" diff --git a/tools/serpapi/api.go b/tools/serpapi/api.go new file mode 100644 index 00000000..d36d36b5 --- /dev/null +++ b/tools/serpapi/api.go @@ -0,0 +1,186 @@ +package serpapi + +import ( + "encoding/json" + "io" + + "github.com/henomis/restclientgo" +) + +type request struct { + Query string + GoogleDomain string + CountryCode string + LanguageCode string + ApiKey string +} + +type response struct { + HTTPStatusCode int + Map map[string]interface{} + RawBody []byte + apiResponse apiResponse + Results []result +} + +type apiResponse struct { + SearchMetadata SearchMetadata `json:"search_metadata"` + SearchParameters SearchParameters `json:"search_parameters"` + SearchInformation SearchInformation `json:"search_information"` + InlineImagesSuggestedSearches []InlineImagesSuggestedSearches `json:"inline_images_suggested_searches"` + InlineImages []InlineImages `json:"inline_images"` + AnswerBox AnswerBox `json:"answer_box"` + OrganicResults []OrganicResults `json:"organic_results"` + Pagination Pagination `json:"pagination"` + SerpapiPagination SerpapiPagination `json:"serpapi_pagination"` +} +type SearchMetadata struct { + ID string `json:"id"` + Status string `json:"status"` + JSONEndpoint string `json:"json_endpoint"` + CreatedAt string `json:"created_at"` + ProcessedAt string `json:"processed_at"` + GoogleURL string `json:"google_url"` + RawHTMLFile string `json:"raw_html_file"` + TotalTimeTaken float64 `json:"total_time_taken"` +} +type SearchParameters struct { + Engine string `json:"engine"` + Q string `json:"q"` + GoogleDomain string `json:"google_domain"` + Hl string `json:"hl"` + Gl string `json:"gl"` + Device string `json:"device"` +} +type SearchInformation struct { + QueryDisplayed string `json:"query_displayed"` + TotalResults int `json:"total_results"` + TimeTakenDisplayed float64 `json:"time_taken_displayed"` + OrganicResultsState string `json:"organic_results_state"` +} +type InlineImagesSuggestedSearches struct { + Name string `json:"name"` + Link string `json:"link"` + Uds string `json:"uds"` + Q string `json:"q"` + SerpapiLink string `json:"serpapi_link"` + Thumbnail string `json:"thumbnail"` +} +type InlineImages struct { + Link string `json:"link"` + Source string `json:"source"` + Thumbnail string `json:"thumbnail"` + Original string `json:"original"` + Title string `json:"title"` + SourceName string `json:"source_name"` +} +type AnswerBox struct { + Type string `json:"type"` + Title string `json:"title"` + Thumbnail string `json:"thumbnail"` +} +type Top struct { + Extensions []string `json:"extensions"` +} +type RichSnippet struct { + Top Top `json:"top"` +} +type OrganicResults struct { + Position int `json:"position"` + Title string `json:"title"` + Link string `json:"link"` + RedirectLink string `json:"redirect_link"` + DisplayedLink string `json:"displayed_link"` + Thumbnail string `json:"thumbnail,omitempty"` + Favicon string `json:"favicon"` + Snippet string `json:"snippet"` + Source string `json:"source"` + RichSnippet RichSnippet `json:"rich_snippet,omitempty"` + SnippetHighlightedWords []string `json:"snippet_highlighted_words,omitempty"` +} +type OtherPages struct { + Num2 string `json:"2"` + Num3 string `json:"3"` + Num4 string `json:"4"` + Num5 string `json:"5"` +} +type Pagination struct { + Current int `json:"current"` + Next string `json:"next"` + OtherPages OtherPages `json:"other_pages"` +} +type SerpapiPagination struct { + Current int `json:"current"` + NextLink string `json:"next_link"` + Next string `json:"next"` + OtherPages OtherPages `json:"other_pages"` +} + +type result struct { + Title string + Info string + URL string +} + +func (r *request) Path() (string, error) { + urlValues := restclientgo.NewURLValues() + urlValues.Add("q", &r.Query) + urlValues.Add("api_key", &r.ApiKey) + + if r.GoogleDomain != "" { + urlValues.Add("google_domain", &r.GoogleDomain) + } + + if r.CountryCode != "" { + urlValues.Add("gl", &r.CountryCode) + } + + if r.LanguageCode != "" { + urlValues.Add("hl", &r.LanguageCode) + } + + params := urlValues.Encode() + + return "/search?" + params, nil +} + +func (r *request) Encode() (io.Reader, error) { + return nil, nil +} + +func (r *request) ContentType() string { + return "" +} + +func (r *response) Decode(body io.Reader) error { + err := json.NewDecoder(body).Decode(&r.apiResponse) + if err != nil { + return err + } + + for _, res := range r.apiResponse.OrganicResults { + r.Results = append(r.Results, result{ + Title: res.Title, + Info: res.Snippet, + URL: res.Link, + }) + } + + return nil +} + +func (r *response) SetBody(body io.Reader) error { + r.RawBody, _ = io.ReadAll(body) + return nil +} + +func (r *response) AcceptContentType() string { + return "application/json" +} + +func (r *response) SetStatusCode(code int) error { + r.HTTPStatusCode = code + return nil +} + +func (r *response) SetHeaders(_ restclientgo.Headers) error { return nil } diff --git a/tools/serpapi/serpapi.go b/tools/serpapi/serpapi.go new file mode 100644 index 00000000..c524bff5 --- /dev/null +++ b/tools/serpapi/serpapi.go @@ -0,0 +1,90 @@ +package serpapi + +import ( + "context" + "fmt" + "os" + + "github.com/henomis/restclientgo" +) + +type Tool struct { + restClient *restclientgo.RestClient + googleDomain string + countryCode string + languageCode string + apiKey string +} + +type Input struct { + Query string `json:"query" jsonschema:"description=the query to search for"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Results []result `json:"results,omitempty"` +} + +type FnPrototype = func(Input) Output + +func New() *Tool { + t := &Tool{ + apiKey: os.Getenv("SERPAPI_API_KEY"), + restClient: restclientgo.New("https://serpapi.com"), + googleDomain: "google.com", + countryCode: "us", + languageCode: "en", + } + + return t +} + +func (t *Tool) WithGoogleDomain(googleDomain string) *Tool { + t.googleDomain = googleDomain + return t +} + +func (t *Tool) WithCountryCode(countryCode string) *Tool { + t.countryCode = countryCode + return t +} + +func (t *Tool) WithLanguageCode(languageCode string) *Tool { + t.languageCode = languageCode + return t +} + +func (t *Tool) WithApiKey(apiKey string) *Tool { + t.apiKey = apiKey + return t +} + +func (t *Tool) Name() string { + return "google" +} + +func (t *Tool) Description() string { + return "A tool that uses the Google internet search engine for a query." +} + +func (t *Tool) Fn() any { + return t.fn +} + +func (t *Tool) fn(i Input) Output { + req := &request{ + Query: i.Query, + GoogleDomain: t.googleDomain, + CountryCode: t.countryCode, + LanguageCode: t.languageCode, + ApiKey: t.apiKey, + } + res := &response{} + + err := t.restClient.Get(context.Background(), req, res) + if err != nil { + return Output{Error: fmt.Sprintf("failed to search serpapi: %v", err)} + } + + return Output{Results: res.Results} +} From 290da52606bd79d8e81592902f3206e9bc73c2f7 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 24 May 2024 11:03:37 +0200 Subject: [PATCH 11/15] fix: remove unused structs --- tools/serpapi/api.go | 75 +++----------------------------------------- 1 file changed, 4 insertions(+), 71 deletions(-) diff --git a/tools/serpapi/api.go b/tools/serpapi/api.go index d36d36b5..1a7d7ffe 100644 --- a/tools/serpapi/api.go +++ b/tools/serpapi/api.go @@ -24,67 +24,17 @@ type response struct { } type apiResponse struct { - SearchMetadata SearchMetadata `json:"search_metadata"` - SearchParameters SearchParameters `json:"search_parameters"` - SearchInformation SearchInformation `json:"search_information"` - InlineImagesSuggestedSearches []InlineImagesSuggestedSearches `json:"inline_images_suggested_searches"` - InlineImages []InlineImages `json:"inline_images"` - AnswerBox AnswerBox `json:"answer_box"` - OrganicResults []OrganicResults `json:"organic_results"` - Pagination Pagination `json:"pagination"` - SerpapiPagination SerpapiPagination `json:"serpapi_pagination"` -} -type SearchMetadata struct { - ID string `json:"id"` - Status string `json:"status"` - JSONEndpoint string `json:"json_endpoint"` - CreatedAt string `json:"created_at"` - ProcessedAt string `json:"processed_at"` - GoogleURL string `json:"google_url"` - RawHTMLFile string `json:"raw_html_file"` - TotalTimeTaken float64 `json:"total_time_taken"` -} -type SearchParameters struct { - Engine string `json:"engine"` - Q string `json:"q"` - GoogleDomain string `json:"google_domain"` - Hl string `json:"hl"` - Gl string `json:"gl"` - Device string `json:"device"` -} -type SearchInformation struct { - QueryDisplayed string `json:"query_displayed"` - TotalResults int `json:"total_results"` - TimeTakenDisplayed float64 `json:"time_taken_displayed"` - OrganicResultsState string `json:"organic_results_state"` -} -type InlineImagesSuggestedSearches struct { - Name string `json:"name"` - Link string `json:"link"` - Uds string `json:"uds"` - Q string `json:"q"` - SerpapiLink string `json:"serpapi_link"` - Thumbnail string `json:"thumbnail"` -} -type InlineImages struct { - Link string `json:"link"` - Source string `json:"source"` - Thumbnail string `json:"thumbnail"` - Original string `json:"original"` - Title string `json:"title"` - SourceName string `json:"source_name"` -} -type AnswerBox struct { - Type string `json:"type"` - Title string `json:"title"` - Thumbnail string `json:"thumbnail"` + OrganicResults []OrganicResults `json:"organic_results"` } + type Top struct { Extensions []string `json:"extensions"` } + type RichSnippet struct { Top Top `json:"top"` } + type OrganicResults struct { Position int `json:"position"` Title string `json:"title"` @@ -98,23 +48,6 @@ type OrganicResults struct { RichSnippet RichSnippet `json:"rich_snippet,omitempty"` SnippetHighlightedWords []string `json:"snippet_highlighted_words,omitempty"` } -type OtherPages struct { - Num2 string `json:"2"` - Num3 string `json:"3"` - Num4 string `json:"4"` - Num5 string `json:"5"` -} -type Pagination struct { - Current int `json:"current"` - Next string `json:"next"` - OtherPages OtherPages `json:"other_pages"` -} -type SerpapiPagination struct { - Current int `json:"current"` - NextLink string `json:"next_link"` - Next string `json:"next"` - OtherPages OtherPages `json:"other_pages"` -} type result struct { Title string From 88a2bce4998ff6586754bbf3541823c94cd33da5 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 24 May 2024 11:04:51 +0200 Subject: [PATCH 12/15] fix linting --- tools/serpapi/api.go | 4 ++-- tools/serpapi/serpapi.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/serpapi/api.go b/tools/serpapi/api.go index 1a7d7ffe..944b71fe 100644 --- a/tools/serpapi/api.go +++ b/tools/serpapi/api.go @@ -12,7 +12,7 @@ type request struct { GoogleDomain string CountryCode string LanguageCode string - ApiKey string + APIKey string } type response struct { @@ -58,7 +58,7 @@ type result struct { func (r *request) Path() (string, error) { urlValues := restclientgo.NewURLValues() urlValues.Add("q", &r.Query) - urlValues.Add("api_key", &r.ApiKey) + urlValues.Add("api_key", &r.APIKey) if r.GoogleDomain != "" { urlValues.Add("google_domain", &r.GoogleDomain) diff --git a/tools/serpapi/serpapi.go b/tools/serpapi/serpapi.go index c524bff5..e4dfb3a8 100644 --- a/tools/serpapi/serpapi.go +++ b/tools/serpapi/serpapi.go @@ -54,7 +54,7 @@ func (t *Tool) WithLanguageCode(languageCode string) *Tool { return t } -func (t *Tool) WithApiKey(apiKey string) *Tool { +func (t *Tool) WithAPIKey(apiKey string) *Tool { t.apiKey = apiKey return t } @@ -77,7 +77,7 @@ func (t *Tool) fn(i Input) Output { GoogleDomain: t.googleDomain, CountryCode: t.countryCode, LanguageCode: t.languageCode, - ApiKey: t.apiKey, + APIKey: t.apiKey, } res := &response{} From c381b981228f29e0a659192a855cc77b43a0392e Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 24 May 2024 11:21:18 +0200 Subject: [PATCH 13/15] chore: I'm not very happy with that but it's a way to use context in tools --- tools/dalle/dalle.go | 10 +++++++++- tools/duckduckgo/duckduckgo.go | 10 +++++++++- tools/llm/llm.go | 10 +++++++++- tools/rag/rag.go | 10 +++++++++- tools/serpapi/serpapi.go | 10 +++++++++- tools/tool_router/tool_router.go | 10 +++++++++- 6 files changed, 54 insertions(+), 6 deletions(-) diff --git a/tools/dalle/dalle.go b/tools/dalle/dalle.go index ebd75882..03e3cb5e 100644 --- a/tools/dalle/dalle.go +++ b/tools/dalle/dalle.go @@ -3,10 +3,15 @@ package dalle import ( "context" "fmt" + "time" "github.com/henomis/lingoose/transformer" ) +const ( + defaultTimeoutInSeconds = 60 +) + type Tool struct { } @@ -38,8 +43,11 @@ func (t *Tool) Fn() any { } func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInSeconds*time.Second) + defer cancel() + d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize512x512) - imageURL, err := d.Transform(context.Background(), i.Description) + imageURL, err := d.Transform(ctx, i.Description) if err != nil { return Output{Error: fmt.Sprintf("error creating image: %v", err)} } diff --git a/tools/duckduckgo/duckduckgo.go b/tools/duckduckgo/duckduckgo.go index 9601e3c0..ec374942 100644 --- a/tools/duckduckgo/duckduckgo.go +++ b/tools/duckduckgo/duckduckgo.go @@ -4,10 +4,15 @@ import ( "context" "fmt" "net/http" + "time" "github.com/henomis/restclientgo" ) +const ( + defaultTimeoutInSeconds = 60 +) + type Tool struct { maxResults uint userAgent string @@ -65,10 +70,13 @@ func (t *Tool) Fn() any { } func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInSeconds*time.Second) + defer cancel() + req := &request{Query: i.Query} res := &response{MaxResults: t.maxResults} - err := t.restClient.Get(context.Background(), req, res) + err := t.restClient.Get(ctx, req, res) if err != nil { return Output{Error: fmt.Sprintf("failed to search DuckDuckGo: %v", err)} } diff --git a/tools/llm/llm.go b/tools/llm/llm.go index 8b05bb55..4f6190cb 100644 --- a/tools/llm/llm.go +++ b/tools/llm/llm.go @@ -2,10 +2,15 @@ package llm import ( "context" + "time" "github.com/henomis/lingoose/thread" ) +const ( + defaultTimeoutInMinutes = 6 +) + type LLM interface { Generate(context.Context, *thread.Thread) error } @@ -45,13 +50,16 @@ func (t *Tool) Fn() any { //nolint:gosec func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInMinutes*time.Minute) + defer cancel() + th := thread.New().AddMessage( thread.NewUserMessage().AddContent( thread.NewTextContent(i.Query), ), ) - err := t.llm.Generate(context.Background(), th) + err := t.llm.Generate(ctx, th) if err != nil { return Output{Error: err.Error()} } diff --git a/tools/rag/rag.go b/tools/rag/rag.go index 54a5554f..c46788cf 100644 --- a/tools/rag/rag.go +++ b/tools/rag/rag.go @@ -3,10 +3,15 @@ package rag import ( "context" "strings" + "time" "github.com/henomis/lingoose/rag" ) +const ( + defaultTimeoutInMinutes = 6 +) + type Tool struct { rag *rag.RAG topic string @@ -44,7 +49,10 @@ func (t *Tool) Fn() any { //nolint:gosec func (t *Tool) fn(i Input) Output { - results, err := t.rag.Retrieve(context.Background(), i.Query) + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInMinutes*time.Minute) + defer cancel() + + results, err := t.rag.Retrieve(ctx, i.Query) if err != nil { return Output{Error: err.Error()} } diff --git a/tools/serpapi/serpapi.go b/tools/serpapi/serpapi.go index e4dfb3a8..9637e691 100644 --- a/tools/serpapi/serpapi.go +++ b/tools/serpapi/serpapi.go @@ -4,10 +4,15 @@ import ( "context" "fmt" "os" + "time" "github.com/henomis/restclientgo" ) +const ( + defaultTimeoutInSeconds = 60 +) + type Tool struct { restClient *restclientgo.RestClient googleDomain string @@ -72,6 +77,9 @@ func (t *Tool) Fn() any { } func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInSeconds*time.Second) + defer cancel() + req := &request{ Query: i.Query, GoogleDomain: t.googleDomain, @@ -81,7 +89,7 @@ func (t *Tool) fn(i Input) Output { } res := &response{} - err := t.restClient.Get(context.Background(), req, res) + err := t.restClient.Get(ctx, req, res) if err != nil { return Output{Error: fmt.Sprintf("failed to search serpapi: %v", err)} } diff --git a/tools/tool_router/tool_router.go b/tools/tool_router/tool_router.go index ec12a231..17d35bcd 100644 --- a/tools/tool_router/tool_router.go +++ b/tools/tool_router/tool_router.go @@ -2,10 +2,15 @@ package toolrouter import ( "context" + "time" "github.com/henomis/lingoose/thread" ) +const ( + defaultTimeoutInMinutes = 6 +) + type TTool interface { Description() string Name() string @@ -53,6 +58,9 @@ func (t *Tool) Fn() any { //nolint:gosec func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInMinutes*time.Minute) + defer cancel() + query := "Here's a list of available tools:\n\n" for _, tool := range t.tools { query += "Name: " + tool.Name() + "\nDescription: " + tool.Description() + "\n\n" @@ -67,7 +75,7 @@ func (t *Tool) fn(i Input) Output { ), ) - err := t.llm.Generate(context.Background(), th) + err := t.llm.Generate(ctx, th) if err != nil { return Output{Error: err.Error()} } From e22c3b4f192b4be9fd885c14c549e60717f3e83f Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Sat, 25 May 2024 13:39:30 +0200 Subject: [PATCH 14/15] chore: add shell tool --- examples/llm/openai/tools/rag/main.go | 9 ++-- examples/tools/bash/main.go | 16 +++++++ tools/shell/shell.go | 68 +++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 3 deletions(-) create mode 100644 examples/tools/bash/main.go create mode 100644 tools/shell/shell.go diff --git a/examples/llm/openai/tools/rag/main.go b/examples/llm/openai/tools/rag/main.go index 6cf90fec..29deadc9 100644 --- a/examples/llm/openai/tools/rag/main.go +++ b/examples/llm/openai/tools/rag/main.go @@ -13,6 +13,7 @@ import ( "github.com/henomis/lingoose/thread" ragtool "github.com/henomis/lingoose/tools/rag" "github.com/henomis/lingoose/tools/serpapi" + "github.com/henomis/lingoose/tools/shell" ) func main() { @@ -38,17 +39,19 @@ func main() { llm := openai.New().WithModel(openai.GPT4o).WithToolChoice(newStr("auto")).WithTools( ragtool.New(rag, "US covid vaccines"), serpapi.New(), + shell.New(), ) topics := []string{ - "how many covid vaccine doses US has donated to other countries", - "who's the author of LinGoose github project", + "how many covid vaccine doses US has donated to other countries.", + "who's the author of LinGoose github project.", + "which process is consuming the most memory.", } for _, topic := range topics { t := thread.New().AddMessage( thread.NewUserMessage().AddContent( - thread.NewTextContent("Please tell me " + topic + "."), + thread.NewTextContent("Please tell me " + topic), ), ) diff --git a/examples/tools/bash/main.go b/examples/tools/bash/main.go new file mode 100644 index 00000000..3ec73377 --- /dev/null +++ b/examples/tools/bash/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + + "github.com/henomis/lingoose/tools/shell" +) + +func main() { + t := shell.New() + + bashScript := `echo "Hello from $SHELL!"` + f := t.Fn().(shell.FnPrototype) + + fmt.Println(f(shell.Input{BashScript: bashScript})) +} diff --git a/tools/shell/shell.go b/tools/shell/shell.go new file mode 100644 index 00000000..52ba6b89 --- /dev/null +++ b/tools/shell/shell.go @@ -0,0 +1,68 @@ +package shell + +import ( + "bytes" + "fmt" + "os/exec" +) + +type Tool struct { + shell string +} + +func New() *Tool { + return &Tool{ + shell: "bash", + } +} + +func (t *Tool) WithShell(shell string) *Tool { + t.shell = shell + return t +} + +type Input struct { + BashScript string `json:"bash_code" jsonschema:"description=shell script"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Result string `json:"result,omitempty"` +} + +type FnPrototype = func(Input) Output + +func (t *Tool) Name() string { + return "bash" +} + +func (t *Tool) Description() string { + return "A tool that runs a shell script using the " + t.shell + " interpreter. Use it to interact with the OS." +} + +func (t *Tool) Fn() any { + return t.fn +} + +//nolint:gosec +func (t *Tool) fn(i Input) Output { + // Create a command to run the Bash interpreter with the script. + cmd := exec.Command(t.shell, "-c", i.BashScript) + + // Create a buffer to capture the output. + var out bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &stderr + + // Run the command. + err := cmd.Run() + if err != nil { + return Output{ + Error: fmt.Sprintf("failed to run script: %v, stderr: %v", err, stderr.String()), + } + } + + // Return the output as a string. + return Output{Result: out.String()} +} From 17babb8f2ba40d5a2456edee8d914e1fe0e4fbba Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Sat, 25 May 2024 13:44:26 +0200 Subject: [PATCH 15/15] chore: refactor shell tool --- examples/tools/{bash => shell}/main.go | 0 tools/shell/shell.go | 27 ++++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) rename examples/tools/{bash => shell}/main.go (100%) diff --git a/examples/tools/bash/main.go b/examples/tools/shell/main.go similarity index 100% rename from examples/tools/bash/main.go rename to examples/tools/shell/main.go diff --git a/tools/shell/shell.go b/tools/shell/shell.go index 52ba6b89..c5c5e308 100644 --- a/tools/shell/shell.go +++ b/tools/shell/shell.go @@ -7,12 +7,14 @@ import ( ) type Tool struct { - shell string + shell string + askForConfirm bool } func New() *Tool { return &Tool{ - shell: "bash", + shell: "bash", + askForConfirm: true, } } @@ -21,6 +23,11 @@ func (t *Tool) WithShell(shell string) *Tool { return t } +func (t *Tool) WithAskForConfirm(askForConfirm bool) *Tool { + t.askForConfirm = askForConfirm + return t +} + type Input struct { BashScript string `json:"bash_code" jsonschema:"description=shell script"` } @@ -46,6 +53,22 @@ func (t *Tool) Fn() any { //nolint:gosec func (t *Tool) fn(i Input) Output { + // Ask for confirmation if the flag is set. + if t.askForConfirm { + fmt.Println("Are you sure you want to run the following script?") + fmt.Println("-------------------------------------------------") + fmt.Println(i.BashScript) + fmt.Println("-------------------------------------------------") + fmt.Print("Type 'yes' to confirm > ") + var confirm string + fmt.Scanln(&confirm) + if confirm != "yes" { + return Output{ + Error: "script execution aborted", + } + } + } + // Create a command to run the Bash interpreter with the script. cmd := exec.Command(t.shell, "-c", i.BashScript)