61 changes: 60 additions & 1 deletion internal/agent/logic.go
@@ -339,6 +339,60 @@ func (a *defaultAgent) getTools() []*llm.FunctionDefinition {
}
}

func (a *defaultAgent) proposeInferenceWithAIAssist(ctx context.Context, initialErr error) (*schema.StrategyOneOf, error) {
log.Println("Getting repository hint from AI...")
prompt := []string{
fmt.Sprintf("Based on the following inference failure error \"%v\" for package '%s', find the correct source code repository URL.", initialErr, a.t.Package),
"Just return the URL WITHOUT any additional text or formatting.",
"For example, for the package 'org.apache.camel:camel-support', return 'https://github.com/apache/camel' not 'https://github.com/apache/camel/tree/main/core/camel-support'.",
"Use the tools you have at your disposal to find the URL.",
"Finally, if you don't find the URL, just return an empty string.",
}
repoURL, err := llm.GenerateTextContent(ctx, a.deps.Client, llm.GeminiPro, &genai.GenerateContentConfig{
Temperature: genai.Ptr(float32(0.0)),
Tools: []*genai.Tool{
{GoogleSearch: &genai.GoogleSearch{}},
},
}, genai.NewPartFromText(strings.Join(prompt, "\n")))
if err != nil {
return nil, errors.Wrap(err, "getting AI repo hint")
}
if repoURL == "" {
return nil, errors.Wrap(initialErr, "AI could not find a repository hint")
}
log.Printf("AI suggested repo hint: %s", repoURL)
req := schema.InferenceRequest{
Ecosystem: a.t.Ecosystem,
Package: a.t.Package,
Version: a.t.Version,
Artifact: a.t.Artifact,
StrategyHint: &schema.StrategyOneOf{
LocationHint: &rebuild.LocationHint{
Location: rebuild.Location{
Repo: repoURL,
},
},
},
}
wt := memfs.New()
str := memory.NewStorage()
s, err := inferenceservice.Infer(
ctx,
req,
&inferenceservice.InferDeps{
HTTPClient: http.DefaultClient,
GitCache: nil,
RepoOptF: func() *gitx.RepositoryOptions {
return &gitx.RepositoryOptions{
Worktree: wt,
Storer: str,
}
},
},
)
return s, errors.Wrap(err, "AI-assisted inference failed")
}

func (a *defaultAgent) proposeNormalInference(ctx context.Context) (*schema.StrategyOneOf, error) {
wt := memfs.New()
str := memory.NewStorage()
@@ -362,7 +416,12 @@ func (a *defaultAgent) proposeNormalInference(ctx context.Context) (*schema.Stra
},
)
if err != nil {
-	return nil, errors.Wrap(err, "inferring initial strategy")
+	log.Printf("Normal inference failed: %v", err)
+	s, err = a.proposeInferenceWithAIAssist(ctx, err)
+	if err != nil {
+		return nil, errors.Wrap(err, "AI-assisted inference failed")
+	}
+	log.Println("AI-assisted inference succeeded.")
}
a.repo, err = git.Open(str, wt)
if err != nil {
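
The prompt above asks the model for a bare repository root (for example https://github.com/apache/camel rather than a /tree/... subpath), but the returned string is passed straight into the LocationHint. If stricter validation were wanted, a small normalizer could sit between those two steps. The helper below is a hypothetical sketch, not part of this PR; keeping only the owner/repo path segments and trimming /tree/ or /blob/ suffixes assumes GitHub-style URLs.

package main

import (
	"fmt"
	"net/url"
	"strings"
)

// normalizeRepoHint is a hypothetical helper (not part of this PR) that trims
// an AI-suggested URL down to a plausible repository root before it is used
// as a LocationHint. Keeping only the first two path segments assumes
// GitHub-style owner/repo URLs.
func normalizeRepoHint(raw string) (string, error) {
	u, err := url.Parse(strings.TrimSpace(raw))
	if err != nil || u.Scheme == "" || u.Host == "" {
		return "", fmt.Errorf("not a usable repository URL: %q", raw)
	}
	segments := strings.Split(strings.Trim(u.Path, "/"), "/")
	if len(segments) >= 2 {
		// Drops any /tree/<ref>/... or /blob/<ref>/... suffix the model included.
		u.Path = "/" + segments[0] + "/" + segments[1]
	}
	u.RawQuery, u.Fragment = "", ""
	return strings.TrimSuffix(u.String(), ".git"), nil
}

func main() {
	hint, _ := normalizeRepoHint("https://github.com/apache/camel/tree/main/core/camel-support")
	fmt.Println(hint) // prints https://github.com/apache/camel
}

In the diff above the raw model output is trusted directly, which keeps the change small; the sketch only shows where a guard could go if malformed hints turn out to be common.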
2 changes: 2 additions & 0 deletions internal/agent/session.go
@@ -22,6 +22,7 @@ type AgentDeps struct {
MetadataBucket string
LogsBucket string
MaxTurns int
Client *genai.Client
}

type Agent interface {
@@ -86,6 +87,7 @@ func doSession(ctx context.Context, req RunSessionReq, deps RunSessionDeps) *sch
MetadataBucket: deps.MetadataBucket,
LogsBucket: deps.LogsBucket,
MaxTurns: 10,
Client: deps.Client,
})
var err error
a.deps.Chat, err = llm.NewChat(ctx, deps.Client, llm.GeminiPro, config, &llm.ChatOpts{Tools: a.getTools()})
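
For callers, the practical effect of the new field is that whoever builds AgentDeps must now hand in the same *genai.Client that already backs the chat session, since logic.go calls llm.GenerateTextContent on it directly. A minimal construction sketch follows; it assumes the google.golang.org/genai client is created via genai.NewClient with an API-key backend, and everything other than the Client wiring is illustrative.

package main

import (
	"context"
	"log"
	"os"

	"google.golang.org/genai"
)

func main() {
	ctx := context.Background()
	// Assumption: the google.golang.org/genai SDK is configured with an API key
	// here; adjust the backend/credentials to match the actual deployment.
	client, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey:  os.Getenv("GEMINI_API_KEY"),
		Backend: genai.BackendGeminiAPI,
	})
	if err != nil {
		log.Fatalf("creating genai client: %v", err)
	}
	// The same client now backs both the chat handle and the direct
	// GenerateTextContent call added in logic.go; it is threaded through
	// RunSessionDeps into AgentDeps as deps.Client.
	_ = client
}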
3 changes: 3 additions & 0 deletions internal/llm/llm.go
@@ -6,6 +6,7 @@ package llm
import (
"context"
"encoding/json"
"log"

"github.com/pkg/errors"
"google.golang.org/genai"
@@ -61,6 +62,7 @@ type ScriptResponse struct {

func GenerateTextContent(ctx context.Context, client *genai.Client, model string, config *genai.GenerateContentConfig, prompt ...*genai.Part) (string, error) {
contents := []*genai.Content{{Parts: prompt, Role: "user"}}
log.Printf("%s\n\n", FormatContent(*contents[0]))
resp, err := client.Models.GenerateContent(ctx, model, contents, config)
if err != nil {
return "", errors.Wrap(err, "failed to generate content")
@@ -77,6 +79,7 @@ func GenerateTextContent(ctx context.Context, client *genai.Client, model string
return "", errors.New("empty response content")
case 1:
if candidate.Content.Parts[0].Text != "" {
log.Printf("%s\n\n", FormatContent(*candidate.Content))
return candidate.Content.Parts[0].Text, nil
}
return "", errors.New("part is not text")
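
The two new log.Printf calls write the outgoing prompt and the returned candidate through FormatContent, which is defined elsewhere in this package and not shown in the diff; note the logging is unconditional, so full prompts and model replies go to the standard logger. Purely to illustrate the kind of output involved, a minimal stand-in formatter might look like the sketch below; this is an assumption about its behavior, not the actual implementation, and it only handles text and function-call parts.

package main

import (
	"fmt"
	"strings"

	"google.golang.org/genai"
)

// formatContent is an illustrative stand-in for the package's FormatContent
// helper (the real implementation is not shown in this diff). It renders the
// role plus each text or function-call part on its own line.
func formatContent(c genai.Content) string {
	var b strings.Builder
	fmt.Fprintf(&b, "[%s]\n", c.Role)
	for _, p := range c.Parts {
		switch {
		case p.Text != "":
			b.WriteString(p.Text + "\n")
		case p.FunctionCall != nil:
			fmt.Fprintf(&b, "call %s(%v)\n", p.FunctionCall.Name, p.FunctionCall.Args)
		}
	}
	return b.String()
}

func main() {
	c := genai.Content{Role: "user", Parts: []*genai.Part{genai.NewPartFromText("hello")}}
	fmt.Print(formatContent(c))
}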