From d5f38d869626e3729973d827fd001c5e0d0f2d1f Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Wed, 9 Oct 2024 16:28:28 -0500 Subject: [PATCH 01/20] Add lint workflow As copied from another internal repo. --- .github/workflows/lint.yml | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..0b340d5 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,33 @@ +name: "go-linter" + +on: + pull_request: + merge_group: + workflow_dispatch: + +jobs: + lint: + strategy: + fail-fast: false + runs-on: ubuntu-latest-xl + env: + GOPROXY: https://goproxy.githubapp.com/mod,https://proxy.golang.org/,direct + GOPRIVATE: "" + GONOPROXY: "" + GONOSUMDB: github.com/github/* + steps: + - uses: actions/setup-go@v5 + with: + go-version: ${{ vars.GOVERSION }} + check-latest: true + - uses: actions/checkout@v4 + - name: Configure Go private module access + run: | + echo "machine goproxy.githubapp.com login nobody password ${{ secrets.GOPROXY_TOKEN }}" >> $HOME/.netrc + - name: Lint + # This also does checkout, setup-go, and proxy setup. + uses: github/go-linter@v1.2.1 + with: + strict: true + go-version: ${{ vars.GOVERSION }} + goproxy-token: ${{secrets.GOPROXY_TOKEN}} From 0a57639d82a3752d133b25dec2f6821366b1cb0f Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Wed, 9 Oct 2024 16:43:45 -0500 Subject: [PATCH 02/20] Add some docs, handle some errors, appease the linter --- cmd/list/list.go | 17 +++++++++++++---- cmd/root.go | 2 ++ cmd/run/run.go | 11 +++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/cmd/list/list.go b/cmd/list/list.go index 8ea9e11..dee59dd 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -1,3 +1,4 @@ +// list provides a gh command to list available models. package list import ( @@ -17,6 +18,7 @@ var ( lightGrayUnderline = ansi.ColorFunc("white+du") ) +// NewListCommand returns a new command to list available GitHub models. func NewListCommand() *cobra.Command { cmd := &cobra.Command{ Use: "list", @@ -28,7 +30,7 @@ func NewListCommand() *cobra.Command { token, _ := auth.TokenForHost("github.com") if token == "" { - io.WriteString(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + writeToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") return nil } @@ -47,9 +49,9 @@ func NewListCommand() *cobra.Command { isTTY := terminal.IsTerminalOutput() if isTTY { - io.WriteString(out, "\n") - io.WriteString(out, fmt.Sprintf("Showing %d available chat models\n", len(models))) - io.WriteString(out, "\n") + writeToOut(out, "\n") + writeToOut(out, fmt.Sprintf("Showing %d available chat models\n", len(models))) + writeToOut(out, "\n") } width, _, _ := terminal.Size() @@ -75,3 +77,10 @@ func NewListCommand() *cobra.Command { return cmd } + +func writeToOut(out io.Writer, message string) { + _, err := io.WriteString(out, message) + if err != nil { + fmt.Println("Error writing message:", err) + } +} diff --git a/cmd/root.go b/cmd/root.go index ed72d29..94d344c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,3 +1,4 @@ +// cmd represents the base command when called without any subcommands. package cmd import ( @@ -6,6 +7,7 @@ import ( "github.com/spf13/cobra" ) +// NewRootCommand returns a new root command for the gh-models extension. func NewRootCommand() *cobra.Command { cmd := &cobra.Command{ Use: "gh models", diff --git a/cmd/run/run.go b/cmd/run/run.go index 54dd471..6f59b26 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -1,3 +1,4 @@ +// run provides a gh command to run a GitHub model. package run import ( @@ -20,12 +21,14 @@ import ( "github.com/spf13/pflag" ) +// ModelParameters represents the parameters that can be set for a model run. type ModelParameters struct { maxTokens *int temperature *float64 topP *float64 } +// FormatParameter returns a string representation of the parameter value. func (mp *ModelParameters) FormatParameter(name string) string { switch name { case "max-tokens": @@ -47,6 +50,7 @@ func (mp *ModelParameters) FormatParameter(name string) string { return "" } +// PopulateFromFlags populates the model parameters from the given flags. func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { maxTokensString, err := flags.GetString("max-tokens") if err != nil { @@ -87,6 +91,7 @@ func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { return nil } +// SetParameterByName sets the parameter with the given name to the given value. func (mp *ModelParameters) SetParameterByName(name string, value string) error { switch name { case "max-tokens": @@ -117,17 +122,20 @@ func (mp *ModelParameters) SetParameterByName(name string, value string) error { return nil } +// UpdateRequest updates the given request with the model parameters. func (mp *ModelParameters) UpdateRequest(req *azure_models.ChatCompletionOptions) { req.MaxTokens = mp.maxTokens req.Temperature = mp.temperature req.TopP = mp.topP } +// Conversation represents a conversation between the user and the model. type Conversation struct { messages []azure_models.ChatMessage systemPrompt string } +// AddMessage adds a message to the conversation. func (c *Conversation) AddMessage(role azure_models.ChatMessageRole, content string) { c.messages = append(c.messages, azure_models.ChatMessage{ Content: azure_models.Ptr(content), @@ -135,6 +143,7 @@ func (c *Conversation) AddMessage(role azure_models.ChatMessageRole, content str }) } +// GetMessages returns the messages in the conversation. func (c *Conversation) GetMessages() []azure_models.ChatMessage { length := len(c.messages) if c.systemPrompt != "" { @@ -159,6 +168,7 @@ func (c *Conversation) GetMessages() []azure_models.ChatMessage { return messages } +// Reset removes messages from the conversation. func (c *Conversation) Reset() { c.messages = nil } @@ -176,6 +186,7 @@ func isPipe(r io.Reader) bool { return false } +// NewRunCommand returns a new gh command for running a model. func NewRunCommand() *cobra.Command { cmd := &cobra.Command{ Use: "run [model] [prompt]", From c0acf3fd7decaef6b94a530475383d17ce654b36 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:02:12 -0500 Subject: [PATCH 03/20] Move WriteToOut to new util package --- cmd/list/list.go | 17 +++++------------ pkg/util/util.go | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 12 deletions(-) create mode 100644 pkg/util/util.go diff --git a/cmd/list/list.go b/cmd/list/list.go index dee59dd..225ba54 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -3,13 +3,13 @@ package list import ( "fmt" - "io" "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/tableprinter" "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azure_models" "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/pkg/util" "github.com/mgutz/ansi" "github.com/spf13/cobra" ) @@ -30,7 +30,7 @@ func NewListCommand() *cobra.Command { token, _ := auth.TokenForHost("github.com") if token == "" { - writeToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") return nil } @@ -49,9 +49,9 @@ func NewListCommand() *cobra.Command { isTTY := terminal.IsTerminalOutput() if isTTY { - writeToOut(out, "\n") - writeToOut(out, fmt.Sprintf("Showing %d available chat models\n", len(models))) - writeToOut(out, "\n") + util.WriteToOut(out, "\n") + util.WriteToOut(out, fmt.Sprintf("Showing %d available chat models\n", len(models))) + util.WriteToOut(out, "\n") } width, _, _ := terminal.Size() @@ -77,10 +77,3 @@ func NewListCommand() *cobra.Command { return cmd } - -func writeToOut(out io.Writer, message string) { - _, err := io.WriteString(out, message) - if err != nil { - fmt.Println("Error writing message:", err) - } -} diff --git a/pkg/util/util.go b/pkg/util/util.go new file mode 100644 index 0000000..1ae277c --- /dev/null +++ b/pkg/util/util.go @@ -0,0 +1,14 @@ +package util + +import ( + "fmt" + "io" +) + +// WriteToOut writes a message to the given io.Writer. +func WriteToOut(out io.Writer, message string) { + _, err := io.WriteString(out, message) + if err != nil { + fmt.Println("Error writing message:", err) + } +} From b48725302e798cddf3a8f4c3f96028e98aabdceb Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:02:48 -0500 Subject: [PATCH 04/20] Fix errcheck linter errors > Error return value of `io.WriteString` is not checked --- cmd/run/run.go | 47 ++++++++++++++++++++++++----------------------- cmd/view/view.go | 4 ++-- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index 93cd164..d154316 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -17,6 +17,7 @@ import ( "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azure_models" "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" "github.com/spf13/pflag" ) @@ -199,7 +200,7 @@ func NewRunCommand() *cobra.Command { token, _ := auth.TokenForHost("github.com") if token == "" { - io.WriteString(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") return nil } @@ -315,24 +316,24 @@ func NewRunCommand() *cobra.Command { } if prompt == "/parameters" { - io.WriteString(out, "Current parameters:\n") + util.WriteToOut(out, "Current parameters:\n") names := []string{"max-tokens", "temperature", "top-p"} for _, name := range names { - io.WriteString(out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) + util.WriteToOut(out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) } - io.WriteString(out, "\n") - io.WriteString(out, "System Prompt:\n") + util.WriteToOut(out, "\n") + util.WriteToOut(out, "System Prompt:\n") if conversation.systemPrompt != "" { - io.WriteString(out, " "+conversation.systemPrompt+"\n") + util.WriteToOut(out, " "+conversation.systemPrompt+"\n") } else { - io.WriteString(out, " \n") + util.WriteToOut(out, " \n") } continue } if prompt == "/reset" || prompt == "/clear" { conversation.Reset() - io.WriteString(out, "Reset chat history\n") + util.WriteToOut(out, "Reset chat history\n") continue } @@ -344,35 +345,35 @@ func NewRunCommand() *cobra.Command { err := mp.SetParameterByName(name, value) if err != nil { - io.WriteString(out, err.Error()+"\n") + util.WriteToOut(out, err.Error()+"\n") continue } - io.WriteString(out, "Set "+name+" to "+value+"\n") + util.WriteToOut(out, "Set "+name+" to "+value+"\n") } else { - io.WriteString(out, "Invalid /set syntax. Usage: /set \n") + util.WriteToOut(out, "Invalid /set syntax. Usage: /set \n") } continue } if strings.HasPrefix(prompt, "/system-prompt ") { conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") - io.WriteString(out, "Updated system prompt\n") + util.WriteToOut(out, "Updated system prompt\n") continue } if prompt == "/help" { - io.WriteString(out, "Commands:\n") - io.WriteString(out, " /bye, /exit, /quit - Exit the chat\n") - io.WriteString(out, " /parameters - Show current model parameters\n") - io.WriteString(out, " /reset, /clear - Reset chat context\n") - io.WriteString(out, " /set - Set a model parameter\n") - io.WriteString(out, " /system-prompt - Set the system prompt\n") - io.WriteString(out, " /help - Show this help message\n") + util.WriteToOut(out, "Commands:\n") + util.WriteToOut(out, " /bye, /exit, /quit - Exit the chat\n") + util.WriteToOut(out, " /parameters - Show current model parameters\n") + util.WriteToOut(out, " /reset, /clear - Reset chat context\n") + util.WriteToOut(out, " /set - Set a model parameter\n") + util.WriteToOut(out, " /system-prompt - Set the system prompt\n") + util.WriteToOut(out, " /help - Show this help message\n") continue } - io.WriteString(out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") + util.WriteToOut(out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") continue } @@ -415,11 +416,11 @@ func NewRunCommand() *cobra.Command { if choice.Delta != nil && choice.Delta.Content != nil { content := choice.Delta.Content messageBuilder.WriteString(*content) - io.WriteString(out, *content) + util.WriteToOut(out, *content) } else if choice.Message != nil && choice.Message.Content != nil { content := choice.Message.Content messageBuilder.WriteString(*content) - io.WriteString(out, *content) + util.WriteToOut(out, *content) } // Introduce a small delay in between response tokens to better simulate a conversation @@ -429,7 +430,7 @@ func NewRunCommand() *cobra.Command { } } - io.WriteString(out, "\n") + util.WriteToOut(out, "\n") messageBuilder.WriteString("\n") conversation.AddMessage(azure_models.ChatMessageRoleAssistant, messageBuilder.String()) diff --git a/cmd/view/view.go b/cmd/view/view.go index f36a051..8e5b914 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -2,13 +2,13 @@ package view import ( "fmt" - "io" "github.com/AlecAivazis/survey/v2" "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azure_models" "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" ) @@ -22,7 +22,7 @@ func NewViewCommand() *cobra.Command { token, _ := auth.TokenForHost("github.com") if token == "" { - io.WriteString(terminal.Out(), "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + util.WriteToOut(terminal.Out(), "No GitHub token found. Please run 'gh auth login' to authenticate.\n") return nil } From 407365275ee30e014afe7b984f5dc20f438a39eb Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:15:01 -0500 Subject: [PATCH 05/20] Update readme with link to gh website --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b8dc4e5..138369c 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Use the GitHub Models service from the CLI! ### Prerequisites -The extension requires the `gh` CLI to be installed and in the PATH. The extension also requires the user have authenticated via `gh auth`. +The extension requires the [`gh` CLI](https://cli.github.com/) to be installed and in the `PATH`. The extension also requires the user have authenticated via `gh auth`. ### Installing @@ -73,4 +73,4 @@ git tag v0.0.x main git push origin tag v0.0.x ``` -This will trigger the `release` action that runs the actual production build. \ No newline at end of file +This will trigger the `release` action that runs the actual production build. From 7fbc049f0ddb0aa5dc0cccd15b9db1a93d891afe Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:24:30 -0500 Subject: [PATCH 06/20] More updates for the linter --- cmd/list/list.go | 2 +- cmd/root.go | 5 +++-- cmd/run/run.go | 20 ++++++++++---------- cmd/view/view.go | 2 ++ internal/azure_models/client.go | 12 ++++++++---- internal/azure_models/types.go | 33 +++++++++++++++++++-------------- pkg/util/util.go | 6 ++++++ 7 files changed, 49 insertions(+), 31 deletions(-) diff --git a/cmd/list/list.go b/cmd/list/list.go index 225ba54..d68b775 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -1,4 +1,4 @@ -// list provides a gh command to list available models. +// Package list provides a gh command to list available models. package list import ( diff --git a/cmd/root.go b/cmd/root.go index 6e65dec..3e8caf5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,12 +1,13 @@ -// cmd represents the base command when called without any subcommands. +// Package cmd represents the base command when called without any subcommands. package cmd import ( + "strings" + "github.com/github/gh-models/cmd/list" "github.com/github/gh-models/cmd/run" "github.com/github/gh-models/cmd/view" "github.com/spf13/cobra" - "strings" ) // NewRootCommand returns a new root command for the gh-models extension. diff --git a/cmd/run/run.go b/cmd/run/run.go index d154316..10ba1c9 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -1,4 +1,4 @@ -// run provides a gh command to run a GitHub model. +// Package run provides a gh command to run a GitHub model. package run import ( @@ -62,7 +62,7 @@ func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { if err != nil { return err } - mp.maxTokens = azure_models.Ptr(maxTokens) + mp.maxTokens = util.Ptr(maxTokens) } temperatureString, err := flags.GetString("temperature") @@ -74,7 +74,7 @@ func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { if err != nil { return err } - mp.temperature = azure_models.Ptr(temperature) + mp.temperature = util.Ptr(temperature) } topPString, err := flags.GetString("top-p") @@ -86,35 +86,35 @@ func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { if err != nil { return err } - mp.topP = azure_models.Ptr(topP) + mp.topP = util.Ptr(topP) } return nil } // SetParameterByName sets the parameter with the given name to the given value. -func (mp *ModelParameters) SetParameterByName(name string, value string) error { +func (mp *ModelParameters) SetParameterByName(name, value string) error { switch name { case "max-tokens": maxTokens, err := strconv.Atoi(value) if err != nil { return err } - mp.maxTokens = azure_models.Ptr(maxTokens) + mp.maxTokens = util.Ptr(maxTokens) case "temperature": temperature, err := strconv.ParseFloat(value, 64) if err != nil { return err } - mp.temperature = azure_models.Ptr(temperature) + mp.temperature = util.Ptr(temperature) case "top-p": topP, err := strconv.ParseFloat(value, 64) if err != nil { return err } - mp.topP = azure_models.Ptr(topP) + mp.topP = util.Ptr(topP) default: return errors.New("unknown parameter '" + name + "'. Supported parameters: max-tokens, temperature, top-p") @@ -139,7 +139,7 @@ type Conversation struct { // AddMessage adds a message to the conversation. func (c *Conversation) AddMessage(role azure_models.ChatMessageRole, content string) { c.messages = append(c.messages, azure_models.ChatMessage{ - Content: azure_models.Ptr(content), + Content: util.Ptr(content), Role: role, }) } @@ -156,7 +156,7 @@ func (c *Conversation) GetMessages() []azure_models.ChatMessage { if c.systemPrompt != "" { messages[0] = azure_models.ChatMessage{ - Content: azure_models.Ptr(c.systemPrompt), + Content: util.Ptr(c.systemPrompt), Role: azure_models.ChatMessageRoleSystem, } startIndex++ diff --git a/cmd/view/view.go b/cmd/view/view.go index 8e5b914..aea4860 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -1,3 +1,4 @@ +// Package view provides a `gh models view` command to view details about a model. package view import ( @@ -12,6 +13,7 @@ import ( "github.com/spf13/cobra" ) +// NewViewCommand returns a new command to view details about a model. func NewViewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "view [model]", diff --git a/internal/azure_models/client.go b/internal/azure_models/client.go index 996c270..2856e50 100644 --- a/internal/azure_models/client.go +++ b/internal/azure_models/client.go @@ -1,3 +1,4 @@ +// Package azure_models provides a client for interacting with the Azure models API. package azure_models import ( @@ -26,6 +27,7 @@ const ( prodModelsURL = azureAiStudioURL + "/asset-gallery/v1.0/models" ) +// NewClient returns a new client using the given auth token. func NewClient(authToken string) *Client { httpClient, _ := api.DefaultHTTPClient() return &Client{ @@ -34,6 +36,7 @@ func NewClient(authToken string) *Client { } } +// GetChatCompletionStream returns a stream of chat completions for the given request. func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatCompletionResponse, error) { // Check if the model name is `o1-mini` or `o1-preview` if req.Model == "o1-mini" || req.Model == "o1-preview" { @@ -72,21 +75,22 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple if req.Stream { // Handle streamed response - chatCompletionResponse.Reader = sse.NewEventReader[ChatCompletion](resp.Body) + chatCompletionResponse.Reader = sse.NewEventReader[chatCompletion](resp.Body) } else { - var completion ChatCompletion + var completion chatCompletion if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { return nil, err } // Create a mock reader that returns the decoded completion - mockReader := sse.NewMockEventReader([]ChatCompletion{completion}) + mockReader := sse.NewMockEventReader([]chatCompletion{completion}) chatCompletionResponse.Reader = mockReader } return &chatCompletionResponse, nil } +// GetModelDetails returns the details of the specified model in a prticular registry. func (c *Client) GetModelDetails(registry string, modelName string, version string) (*ModelDetails, error) { url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) httpReq, err := http.NewRequest("GET", url, nil) @@ -162,6 +166,7 @@ func lowercaseStrings(input []string) []string { return output } +// ListModels returns a list of available models. func (c *Client) ListModels() ([]*ModelSummary, error) { body := bytes.NewReader([]byte(` { @@ -223,7 +228,6 @@ func (c *Client) ListModels() ([]*ModelSummary, error) { } func (c *Client) handleHTTPError(resp *http.Response) error { - sb := strings.Builder{} switch resp.StatusCode { diff --git a/internal/azure_models/types.go b/internal/azure_models/types.go index de76846..6c033a2 100644 --- a/internal/azure_models/types.go +++ b/internal/azure_models/types.go @@ -8,19 +8,25 @@ import ( "github.com/github/gh-models/internal/sse" ) +// ChatMessageRole represents the role of a chat message. type ChatMessageRole string const ( + // ChatMessageRoleAssistant represents a message from the model. ChatMessageRoleAssistant ChatMessageRole = "assistant" - ChatMessageRoleSystem ChatMessageRole = "system" - ChatMessageRoleUser ChatMessageRole = "user" + // ChatMessageRoleSystem represents a system message. + ChatMessageRoleSystem ChatMessageRole = "system" + // ChatMessageRoleUser represents a message from the user. + ChatMessageRoleUser ChatMessageRole = "user" ) +// ChatMessage represents a message from a chat thread with a model. type ChatMessage struct { Content *string `json:"content,omitempty"` Role ChatMessageRole `json:"role"` } +// ChatCompletionOptions represents available options for a chat completion request. type ChatCompletionOptions struct { MaxTokens *int `json:"max_tokens,omitempty"` Messages []ChatMessage `json:"messages"` @@ -30,29 +36,30 @@ type ChatCompletionOptions struct { TopP *float64 `json:"top_p,omitempty"` } -type ChatChoiceMessage struct { +type chatChoiceMessage struct { Content *string `json:"content,omitempty"` Role *string `json:"role,omitempty"` } -type ChatChoiceDelta struct { +type chatChoiceDelta struct { Content *string `json:"content,omitempty"` Role *string `json:"role,omitempty"` } -type ChatChoice struct { - Delta *ChatChoiceDelta `json:"delta,omitempty"` +type chatChoice struct { + Delta *chatChoiceDelta `json:"delta,omitempty"` FinishReason string `json:"finish_reason"` Index int32 `json:"index"` - Message *ChatChoiceMessage `json:"message,omitempty"` + Message *chatChoiceMessage `json:"message,omitempty"` } -type ChatCompletion struct { - Choices []ChatChoice `json:"choices"` +type chatCompletion struct { + Choices []chatChoice `json:"choices"` } +// ChatCompletionResponse represents a response to a chat completion request. type ChatCompletionResponse struct { - Reader sse.Reader[ChatCompletion] + Reader sse.Reader[chatCompletion] } type modelCatalogSearchResponse struct { @@ -82,6 +89,7 @@ type ModelSummary struct { RegistryName string `json:"registry_name"` } +// HasName checks if the model has the given name. func (m *ModelSummary) HasName(name string) bool { return strings.EqualFold(m.FriendlyName, name) || strings.EqualFold(m.Name, name) } @@ -134,10 +142,7 @@ type ModelDetails struct { RateLimitTier string `json:"rateLimitTier"` } +// ContextLimits returns a summary of the context limits for the model. func (m *ModelDetails) ContextLimits() string { return fmt.Sprintf("up to %d input tokens and %d output tokens", m.MaxInputTokens, m.MaxOutputTokens) } - -func Ptr[T any](value T) *T { - return &value -} diff --git a/pkg/util/util.go b/pkg/util/util.go index 1ae277c..1856f20 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -1,3 +1,4 @@ +// Package util provides utility functions for the gh-models extension. package util import ( @@ -12,3 +13,8 @@ func WriteToOut(out io.Writer, message string) { fmt.Println("Error writing message:", err) } } + +// Ptr returns a pointer to the given value. +func Ptr[T any](value T) *T { + return &value +} From 2245c3ff05ab7134a3a617be89dde46e43ad3abf Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:25:35 -0500 Subject: [PATCH 07/20] Explicitly list permissions in linter workflow --- .github/workflows/lint.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 0b340d5..cc5ebdc 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -5,6 +5,9 @@ on: merge_group: workflow_dispatch: +permissions: + contents: read + jobs: lint: strategy: From 22b2ae750dca3d003a1dcb622ea98ad51600717a Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:29:49 -0500 Subject: [PATCH 08/20] Fix unhandled-error linter issues --- cmd/run/run.go | 10 ++++++++-- internal/azure_models/client.go | 33 +++++++++++++++++++++++++++------ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index 10ba1c9..26fb710 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -415,11 +415,17 @@ func NewRunCommand() *cobra.Command { // non-streamed responses use `.Message`, so let's support both if choice.Delta != nil && choice.Delta.Content != nil { content := choice.Delta.Content - messageBuilder.WriteString(*content) + _, err := messageBuilder.WriteString(*content) + if err != nil { + return err + } util.WriteToOut(out, *content) } else if choice.Message != nil && choice.Message.Content != nil { content := choice.Message.Content - messageBuilder.WriteString(*content) + _, err := messageBuilder.WriteString(*content) + if err != nil { + return err + } util.WriteToOut(out, *content) } diff --git a/internal/azure_models/client.go b/internal/azure_models/client.go index 2856e50..1657428 100644 --- a/internal/azure_models/client.go +++ b/internal/azure_models/client.go @@ -229,23 +229,44 @@ func (c *Client) ListModels() ([]*ModelSummary, error) { func (c *Client) handleHTTPError(resp *http.Response) error { sb := strings.Builder{} + var err error switch resp.StatusCode { case http.StatusUnauthorized: - sb.WriteString("unauthorized") + _, err = sb.WriteString("unauthorized") + if err != nil { + return err + } case http.StatusBadRequest: - sb.WriteString("bad request") + _, err = sb.WriteString("bad request") + if err != nil { + return err + } default: - sb.WriteString("unexpected response from the server: " + resp.Status) + _, err = sb.WriteString("unexpected response from the server: " + resp.Status) + if err != nil { + return err + } } body, _ := io.ReadAll(resp.Body) if len(body) > 0 { - sb.WriteString("\n") - sb.Write(body) - sb.WriteString("\n") + _, err = sb.WriteString("\n") + if err != nil { + return err + } + + _, err = sb.Write(body) + if err != nil { + return err + } + + _, err = sb.WriteString("\n") + if err != nil { + return err + } } return errors.New(sb.String()) From c7591b781821116f7749a818dbf9e999122bf31f Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:30:41 -0500 Subject: [PATCH 09/20] Fix paramTypeCombine linter issues --- cmd/view/model_printer.go | 4 ++-- internal/azure_models/client.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/view/model_printer.go b/cmd/view/model_printer.go index d3f20b7..571db71 100644 --- a/cmd/view/model_printer.go +++ b/cmd/view/model_printer.go @@ -59,7 +59,7 @@ func (p *modelPrinter) render() error { return nil } -func (p *modelPrinter) printLabelledLine(label string, value string) { +func (p *modelPrinter) printLabelledLine(label, value string) { if value == "" { return } @@ -76,7 +76,7 @@ func (p *modelPrinter) printLabelledMultiLineList(label string, values []string) p.printMultipleLinesWithLabel(label, strings.Join(values, ", ")) } -func (p *modelPrinter) printMultipleLinesWithLabel(label string, value string) { +func (p *modelPrinter) printMultipleLinesWithLabel(label, value string) { if value == "" { return } diff --git a/internal/azure_models/client.go b/internal/azure_models/client.go index 1657428..42ac0c3 100644 --- a/internal/azure_models/client.go +++ b/internal/azure_models/client.go @@ -91,7 +91,7 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple } // GetModelDetails returns the details of the specified model in a prticular registry. -func (c *Client) GetModelDetails(registry string, modelName string, version string) (*ModelDetails, error) { +func (c *Client) GetModelDetails(registry, modelName, version string) (*ModelDetails, error) { url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) httpReq, err := http.NewRequest("GET", url, nil) if err != nil { From 6ad6204cf6ce10ba90d1d0edcf426e9d1f571bf4 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:32:13 -0500 Subject: [PATCH 10/20] Rename package to azuremodels for linter Error: internal/azure_models/client.go:2:9: var-naming: don't use an underscore in package name (revive) package azure_models ^ --- cmd/list/list.go | 4 +-- cmd/run/run.go | 26 +++++++++---------- cmd/view/model_printer.go | 8 +++--- cmd/view/view.go | 6 ++--- .../{azure_models => azuremodels}/client.go | 4 +-- .../{azure_models => azuremodels}/types.go | 2 +- internal/ux/filtering.go | 10 +++---- internal/ux/sorting.go | 4 +-- 8 files changed, 31 insertions(+), 33 deletions(-) rename internal/{azure_models => azuremodels}/client.go (98%) rename internal/{azure_models => azuremodels}/types.go (99%) diff --git a/cmd/list/list.go b/cmd/list/list.go index d68b775..2d8dd09 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -7,7 +7,7 @@ import ( "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/tableprinter" "github.com/cli/go-gh/v2/pkg/term" - "github.com/github/gh-models/internal/azure_models" + "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" "github.com/github/gh-models/pkg/util" "github.com/mgutz/ansi" @@ -34,7 +34,7 @@ func NewListCommand() *cobra.Command { return nil } - client := azure_models.NewClient(token) + client := azuremodels.NewClient(token) models, err := client.ListModels() if err != nil { diff --git a/cmd/run/run.go b/cmd/run/run.go index 26fb710..aaf64a2 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -15,7 +15,7 @@ import ( "github.com/briandowns/spinner" "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/term" - "github.com/github/gh-models/internal/azure_models" + "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" @@ -124,7 +124,7 @@ func (mp *ModelParameters) SetParameterByName(name, value string) error { } // UpdateRequest updates the given request with the model parameters. -func (mp *ModelParameters) UpdateRequest(req *azure_models.ChatCompletionOptions) { +func (mp *ModelParameters) UpdateRequest(req *azuremodels.ChatCompletionOptions) { req.MaxTokens = mp.maxTokens req.Temperature = mp.temperature req.TopP = mp.topP @@ -132,32 +132,32 @@ func (mp *ModelParameters) UpdateRequest(req *azure_models.ChatCompletionOptions // Conversation represents a conversation between the user and the model. type Conversation struct { - messages []azure_models.ChatMessage + messages []azuremodels.ChatMessage systemPrompt string } // AddMessage adds a message to the conversation. -func (c *Conversation) AddMessage(role azure_models.ChatMessageRole, content string) { - c.messages = append(c.messages, azure_models.ChatMessage{ +func (c *Conversation) AddMessage(role azuremodels.ChatMessageRole, content string) { + c.messages = append(c.messages, azuremodels.ChatMessage{ Content: util.Ptr(content), Role: role, }) } // GetMessages returns the messages in the conversation. -func (c *Conversation) GetMessages() []azure_models.ChatMessage { +func (c *Conversation) GetMessages() []azuremodels.ChatMessage { length := len(c.messages) if c.systemPrompt != "" { length++ } - messages := make([]azure_models.ChatMessage, length) + messages := make([]azuremodels.ChatMessage, length) startIndex := 0 if c.systemPrompt != "" { - messages[0] = azure_models.ChatMessage{ + messages[0] = azuremodels.ChatMessage{ Content: util.Ptr(c.systemPrompt), - Role: azure_models.ChatMessageRoleSystem, + Role: azuremodels.ChatMessageRoleSystem, } startIndex++ } @@ -204,7 +204,7 @@ func NewRunCommand() *cobra.Command { return nil } - client := azure_models.NewClient(token) + client := azuremodels.NewClient(token) models, err := client.ListModels() if err != nil { @@ -377,9 +377,9 @@ func NewRunCommand() *cobra.Command { continue } - conversation.AddMessage(azure_models.ChatMessageRoleUser, prompt) + conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt) - req := azure_models.ChatCompletionOptions{ + req := azuremodels.ChatCompletionOptions{ Messages: conversation.GetMessages(), Model: modelName, } @@ -439,7 +439,7 @@ func NewRunCommand() *cobra.Command { util.WriteToOut(out, "\n") messageBuilder.WriteString("\n") - conversation.AddMessage(azure_models.ChatMessageRoleAssistant, messageBuilder.String()) + conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, messageBuilder.String()) if singleShot { break diff --git a/cmd/view/model_printer.go b/cmd/view/model_printer.go index 571db71..63790f3 100644 --- a/cmd/view/model_printer.go +++ b/cmd/view/model_printer.go @@ -6,7 +6,7 @@ import ( "github.com/cli/cli/v2/pkg/markdown" "github.com/cli/go-gh/v2/pkg/tableprinter" "github.com/cli/go-gh/v2/pkg/term" - "github.com/github/gh-models/internal/azure_models" + "github.com/github/gh-models/internal/azuremodels" "github.com/mgutz/ansi" ) @@ -15,13 +15,13 @@ var ( ) type modelPrinter struct { - modelSummary *azure_models.ModelSummary - modelDetails *azure_models.ModelDetails + modelSummary *azuremodels.ModelSummary + modelDetails *azuremodels.ModelDetails printer tableprinter.TablePrinter terminalWidth int } -func newModelPrinter(summary *azure_models.ModelSummary, details *azure_models.ModelDetails, terminal term.Term) modelPrinter { +func newModelPrinter(summary *azuremodels.ModelSummary, details *azuremodels.ModelDetails, terminal term.Term) modelPrinter { width, _, _ := terminal.Size() printer := tableprinter.New(terminal.Out(), terminal.IsTerminalOutput(), width) return modelPrinter{modelSummary: summary, modelDetails: details, printer: printer, terminalWidth: width} diff --git a/cmd/view/view.go b/cmd/view/view.go index aea4860..6916ee1 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -7,7 +7,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/term" - "github.com/github/gh-models/internal/azure_models" + "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" @@ -28,7 +28,7 @@ func NewViewCommand() *cobra.Command { return nil } - client := azure_models.NewClient(token) + client := azuremodels.NewClient(token) models, err := client.ListModels() if err != nil { @@ -86,7 +86,7 @@ func NewViewCommand() *cobra.Command { } // getModelByName returns the model with the specified name, or an error if no such model exists within the given list. -func getModelByName(modelName string, models []*azure_models.ModelSummary) (*azure_models.ModelSummary, error) { +func getModelByName(modelName string, models []*azuremodels.ModelSummary) (*azuremodels.ModelSummary, error) { for _, model := range models { if model.HasName(modelName) { return model, nil diff --git a/internal/azure_models/client.go b/internal/azuremodels/client.go similarity index 98% rename from internal/azure_models/client.go rename to internal/azuremodels/client.go index 42ac0c3..fb3bb90 100644 --- a/internal/azure_models/client.go +++ b/internal/azuremodels/client.go @@ -1,5 +1,5 @@ -// Package azure_models provides a client for interacting with the Azure models API. -package azure_models +// Package azuremodels provides a client for interacting with the Azure models API. +package azuremodels import ( "bytes" diff --git a/internal/azure_models/types.go b/internal/azuremodels/types.go similarity index 99% rename from internal/azure_models/types.go rename to internal/azuremodels/types.go index 6c033a2..ef8c889 100644 --- a/internal/azure_models/types.go +++ b/internal/azuremodels/types.go @@ -1,4 +1,4 @@ -package azure_models +package azuremodels import ( "encoding/json" diff --git a/internal/ux/filtering.go b/internal/ux/filtering.go index 89dcc17..2376574 100644 --- a/internal/ux/filtering.go +++ b/internal/ux/filtering.go @@ -1,15 +1,13 @@ package ux -import ( - "github.com/github/gh-models/internal/azure_models" -) +import "github.com/github/gh-models/internal/azuremodels" -func IsChatModel(model *azure_models.ModelSummary) bool { +func IsChatModel(model *azuremodels.ModelSummary) bool { return model.Task == "chat-completion" } -func FilterToChatModels(models []*azure_models.ModelSummary) []*azure_models.ModelSummary { - var chatModels []*azure_models.ModelSummary +func FilterToChatModels(models []*azuremodels.ModelSummary) []*azuremodels.ModelSummary { + var chatModels []*azuremodels.ModelSummary for _, model := range models { if IsChatModel(model) { chatModels = append(chatModels, model) diff --git a/internal/ux/sorting.go b/internal/ux/sorting.go index c8c66d6..be1495c 100644 --- a/internal/ux/sorting.go +++ b/internal/ux/sorting.go @@ -5,14 +5,14 @@ import ( "sort" "strings" - "github.com/github/gh-models/internal/azure_models" + "github.com/github/gh-models/internal/azuremodels" ) var ( featuredModelNames = []string{} ) -func SortModels(models []*azure_models.ModelSummary) { +func SortModels(models []*azuremodels.ModelSummary) { sort.Slice(models, func(i, j int) bool { // Sort featured models first, by name isFeaturedI := slices.Contains(featuredModelNames, models[i].Name) From 2920dee5149904ea0362c5dce29b67f5fc741175 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:47:20 -0500 Subject: [PATCH 11/20] More updates for linter --- cmd/list/list.go | 12 +++++++++++- cmd/run/run.go | 5 ++++- internal/azuremodels/client.go | 7 ++++++- internal/azuremodels/types.go | 2 ++ internal/sse/eventreader.go | 1 + internal/ux/filtering.go | 11 +---------- internal/ux/sorting.go | 19 +++++++++++-------- main.go | 5 +++-- 8 files changed, 39 insertions(+), 23 deletions(-) diff --git a/cmd/list/list.go b/cmd/list/list.go index 2d8dd09..c001f3f 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -43,7 +43,7 @@ func NewListCommand() *cobra.Command { // For now, filter to just chat models. // Once other tasks are supported (like embeddings), update the list to show all models, with the task as a column. - models = ux.FilterToChatModels(models) + models = filterToChatModels(models) ux.SortModels(models) isTTY := terminal.IsTerminalOutput() @@ -77,3 +77,13 @@ func NewListCommand() *cobra.Command { return cmd } + +func filterToChatModels(models []*azuremodels.ModelSummary) []*azuremodels.ModelSummary { + var chatModels []*azuremodels.ModelSummary + for _, model := range models { + if ux.IsChatModel(model) { + chatModels = append(chatModels, model) + } + } + return chatModels +} diff --git a/cmd/run/run.go b/cmd/run/run.go index aaf64a2..b8d991a 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -437,7 +437,10 @@ func NewRunCommand() *cobra.Command { } util.WriteToOut(out, "\n") - messageBuilder.WriteString("\n") + _, err = messageBuilder.WriteString("\n") + if err != nil { + return err + } conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, messageBuilder.String()) diff --git a/internal/azuremodels/client.go b/internal/azuremodels/client.go index fb3bb90..3fa2e2b 100644 --- a/internal/azuremodels/client.go +++ b/internal/azuremodels/client.go @@ -16,6 +16,7 @@ import ( "golang.org/x/text/language/display" ) +// Client provides a client for interacting with the Azure models API. type Client struct { client *http.Client token string @@ -93,7 +94,7 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple // GetModelDetails returns the details of the specified model in a prticular registry. func (c *Client) GetModelDetails(registry, modelName, version string) (*ModelDetails, error) { url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) - httpReq, err := http.NewRequest("GET", url, nil) + httpReq, err := http.NewRequest("GET", url, http.NoBody) if err != nil { return nil, err } @@ -105,6 +106,8 @@ func (c *Client) GetModelDetails(registry, modelName, version string) (*ModelDet return nil, err } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return nil, c.handleHTTPError(resp) } @@ -192,6 +195,8 @@ func (c *Client) ListModels() ([]*ModelSummary, error) { return nil, err } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return nil, c.handleHTTPError(resp) } diff --git a/internal/azuremodels/types.go b/internal/azuremodels/types.go index ef8c889..c290d7e 100644 --- a/internal/azuremodels/types.go +++ b/internal/azuremodels/types.go @@ -78,6 +78,7 @@ type modelCatalogSearchSummary struct { Summary string `json:"summary"` } +// ModelSummary includes basic information about a model. type ModelSummary struct { ID string `json:"id"` Name string `json:"name"` @@ -127,6 +128,7 @@ type modelCatalogDetailsResponse struct { } `json:"modelLimits"` } +// ModelDetails includes detailed information about a model. type ModelDetails struct { Description string `json:"description"` Evaluation string `json:"evaluation"` diff --git a/internal/sse/eventreader.go b/internal/sse/eventreader.go index 5eddcc8..391c3e2 100644 --- a/internal/sse/eventreader.go +++ b/internal/sse/eventreader.go @@ -1,5 +1,6 @@ // Forked from https://github.com/Azure/azure-sdk-for-go/blob/4661007ca1fd68b2e31f3502d4282904014fd731/sdk/ai/azopenai/event_reader.go#L18 +// Package sse provides a reader for Server-Sent Events (SSE) streams. package sse import ( diff --git a/internal/ux/filtering.go b/internal/ux/filtering.go index 2376574..b645ac2 100644 --- a/internal/ux/filtering.go +++ b/internal/ux/filtering.go @@ -2,16 +2,7 @@ package ux import "github.com/github/gh-models/internal/azuremodels" +// IsChatModel returns true if the given model is for chat completions. func IsChatModel(model *azuremodels.ModelSummary) bool { return model.Task == "chat-completion" } - -func FilterToChatModels(models []*azuremodels.ModelSummary) []*azuremodels.ModelSummary { - var chatModels []*azuremodels.ModelSummary - for _, model := range models { - if IsChatModel(model) { - chatModels = append(chatModels, model) - } - } - return chatModels -} diff --git a/internal/ux/sorting.go b/internal/ux/sorting.go index be1495c..59b0e4e 100644 --- a/internal/ux/sorting.go +++ b/internal/ux/sorting.go @@ -12,6 +12,7 @@ var ( featuredModelNames = []string{} ) +// SortModels sorts the given models in place, with featured models first, and then by friendly name. func SortModels(models []*azuremodels.ModelSummary) { sort.Slice(models, func(i, j int) bool { // Sort featured models first, by name @@ -20,15 +21,17 @@ func SortModels(models []*azuremodels.ModelSummary) { if isFeaturedI && !isFeaturedJ { return true - } else if !isFeaturedI && isFeaturedJ { - return false - } else { - // Otherwise, sort by friendly name - // Note: sometimes the casing returned by the API is inconsistent, so sort using lowercase values. - friendlyNameI := strings.ToLower(models[i].FriendlyName) - friendlyNameJ := strings.ToLower(models[j].FriendlyName) + } - return friendlyNameI < friendlyNameJ + if !isFeaturedI && isFeaturedJ { + return false } + + // Otherwise, sort by friendly name + // Note: sometimes the casing returned by the API is inconsistent, so sort using lowercase values. + friendlyNameI := strings.ToLower(models[i].FriendlyName) + friendlyNameJ := strings.ToLower(models[j].FriendlyName) + + return friendlyNameI < friendlyNameJ }) } diff --git a/main.go b/main.go index 23f6148..6aec694 100644 --- a/main.go +++ b/main.go @@ -1,3 +1,4 @@ +// Package main provides the entry point for the gh-models extension. package main import ( @@ -20,12 +21,12 @@ func main() { } func mainRun() exitCode { - cmd := cmd.NewRootCommand() + rootCmd := cmd.NewRootCommand() exitCode := exitOK ctx := context.Background() - if _, err := cmd.ExecuteContextC(ctx); err != nil { + if _, err := rootCmd.ExecuteContextC(ctx); err != nil { exitCode = exitError } From 312c54813e795496e261bdb81540a54dc0277fcb Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:55:17 -0500 Subject: [PATCH 12/20] Add context parameter to pass along in requests Linter errors like: Error: internal/azuremodels/client.go:56:33: should rewrite http.NewRequestWithContext or add (*Request).WithContext (noctx) httpReq, err := http.NewRequest("POST", prodInferenceURL, body) ^ --- cmd/list/list.go | 3 ++- cmd/run/run.go | 5 +++-- cmd/view/view.go | 5 +++-- internal/azuremodels/client.go | 13 +++++++------ 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/cmd/list/list.go b/cmd/list/list.go index c001f3f..6ab3088 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -35,8 +35,9 @@ func NewListCommand() *cobra.Command { } client := azuremodels.NewClient(token) + ctx := cmd.Context() - models, err := client.ListModels() + models, err := client.ListModels(ctx) if err != nil { return err } diff --git a/cmd/run/run.go b/cmd/run/run.go index b8d991a..4acc6c6 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -205,8 +205,9 @@ func NewRunCommand() *cobra.Command { } client := azuremodels.NewClient(token) + ctx := cmd.Context() - models, err := client.ListModels() + models, err := client.ListModels(ctx) if err != nil { return err } @@ -390,7 +391,7 @@ func NewRunCommand() *cobra.Command { sp.Start() defer sp.Stop() - resp, err := client.GetChatCompletionStream(req) + resp, err := client.GetChatCompletionStream(ctx, req) if err != nil { return err } diff --git a/cmd/view/view.go b/cmd/view/view.go index 6916ee1..777281d 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -29,8 +29,9 @@ func NewViewCommand() *cobra.Command { } client := azuremodels.NewClient(token) + ctx := cmd.Context() - models, err := client.ListModels() + models, err := client.ListModels(ctx) if err != nil { return err } @@ -67,7 +68,7 @@ func NewViewCommand() *cobra.Command { return err } - modelDetails, err := client.GetModelDetails(modelSummary.RegistryName, modelSummary.Name, modelSummary.Version) + modelDetails, err := client.GetModelDetails(ctx, modelSummary.RegistryName, modelSummary.Name, modelSummary.Version) if err != nil { return err } diff --git a/internal/azuremodels/client.go b/internal/azuremodels/client.go index 3fa2e2b..cd0b9d5 100644 --- a/internal/azuremodels/client.go +++ b/internal/azuremodels/client.go @@ -3,6 +3,7 @@ package azuremodels import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -38,7 +39,7 @@ func NewClient(authToken string) *Client { } // GetChatCompletionStream returns a stream of chat completions for the given request. -func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatCompletionResponse, error) { +func (c *Client) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) { // Check if the model name is `o1-mini` or `o1-preview` if req.Model == "o1-mini" || req.Model == "o1-preview" { req.Stream = false @@ -53,7 +54,7 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple body := bytes.NewReader(bodyBytes) - httpReq, err := http.NewRequest("POST", prodInferenceURL, body) + httpReq, err := http.NewRequestWithContext(ctx, "POST", prodInferenceURL, body) if err != nil { return nil, err } @@ -92,9 +93,9 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple } // GetModelDetails returns the details of the specified model in a prticular registry. -func (c *Client) GetModelDetails(registry, modelName, version string) (*ModelDetails, error) { +func (c *Client) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) - httpReq, err := http.NewRequest("GET", url, http.NoBody) + httpReq, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody) if err != nil { return nil, err } @@ -170,7 +171,7 @@ func lowercaseStrings(input []string) []string { } // ListModels returns a list of available models. -func (c *Client) ListModels() ([]*ModelSummary, error) { +func (c *Client) ListModels(ctx context.Context) ([]*ModelSummary, error) { body := bytes.NewReader([]byte(` { "filters": [ @@ -183,7 +184,7 @@ func (c *Client) ListModels() ([]*ModelSummary, error) { } `)) - httpReq, err := http.NewRequest("POST", prodModelsURL, body) + httpReq, err := http.NewRequestWithContext(ctx, "POST", prodModelsURL, body) if err != nil { return nil, err } From b0832d24113265f751e14b8c0f122e8b5654d1f9 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 13:57:17 -0500 Subject: [PATCH 13/20] Docs updates for linter --- internal/sse/mockeventreader.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/sse/mockeventreader.go b/internal/sse/mockeventreader.go index aa015a7..b8b4a03 100644 --- a/internal/sse/mockeventreader.go +++ b/internal/sse/mockeventreader.go @@ -7,7 +7,7 @@ import ( ) // MockEventReader is a mock implementation of the sse.EventReader. This lets us use EventReader as a common interface -// for models that support streaming (like gpt-4o) and models that do not (like the o1 class of models) +// for models that support streaming (like gpt-4o) and models that do not (like the o1 class of models). type MockEventReader[T any] struct { reader io.ReadCloser scanner *bufio.Scanner @@ -15,6 +15,7 @@ type MockEventReader[T any] struct { index int } +// NewMockEventReader creates a new MockEventReader with the given events. func NewMockEventReader[T any](events []T) *MockEventReader[T] { data := []byte{} reader := io.NopCloser(bytes.NewReader(data)) @@ -22,6 +23,7 @@ func NewMockEventReader[T any](events []T) *MockEventReader[T] { return &MockEventReader[T]{reader: reader, scanner: scanner, events: events, index: 0} } +// Read reads the next event from the stream. func (mer *MockEventReader[T]) Read() (T, error) { if mer.index >= len(mer.events) { var zero T @@ -32,6 +34,7 @@ func (mer *MockEventReader[T]) Read() (T, error) { return event, nil } +// Close closes the Reader and any applicable inner stream state. func (mer *MockEventReader[T]) Close() error { return mer.reader.Close() } From 88d20471483b7ff2c3080ad0726ddb1cc30af12a Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 15:49:15 -0500 Subject: [PATCH 14/20] Split up run func a bit for linter Error: cmd/run/run.go:191:1: Function name: NewRunCommand, Cyclomatic Complexity: 50, Halstead Volume: 6931.94, Maintainability Index: 13 (maintidx) func NewRunCommand() *cobra.Command { ^ --- cmd/run/run.go | 251 ++++++++++++++++++++-------------- internal/azuremodels/types.go | 5 +- 2 files changed, 152 insertions(+), 104 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index 4acc6c6..b4f3ab0 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -214,48 +214,9 @@ func NewRunCommand() *cobra.Command { ux.SortModels(models) - modelName := "" - switch { - case len(args) == 0: - // Need to prompt for a model - prompt := &survey.Select{ - Message: "Select a model:", - Options: []string{}, - } - - for _, model := range models { - if !ux.IsChatModel(model) { - continue - } - prompt.Options = append(prompt.Options, model.FriendlyName) - } - - err = survey.AskOne(prompt, &modelName, survey.WithPageSize(10)) - if err != nil { - return err - } - - case len(args) >= 1: - modelName = args[0] - } - - noMatchErrorMessage := "The specified model name is not found. Run 'gh models list' to see available models or 'gh models run' to select interactively." - - if modelName == "" { - return errors.New(noMatchErrorMessage) - } - - foundMatch := false - for _, model := range models { - if model.HasName(modelName) { - modelName = model.Name - foundMatch = true - break - } - } - - if !foundMatch { - return errors.New(noMatchErrorMessage) + modelName, err := getModelNameFromArgs(args, models) + if err != nil { + return err } initialPrompt := "" @@ -317,64 +278,31 @@ func NewRunCommand() *cobra.Command { } if prompt == "/parameters" { - util.WriteToOut(out, "Current parameters:\n") - names := []string{"max-tokens", "temperature", "top-p"} - for _, name := range names { - util.WriteToOut(out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) - } - util.WriteToOut(out, "\n") - util.WriteToOut(out, "System Prompt:\n") - if conversation.systemPrompt != "" { - util.WriteToOut(out, " "+conversation.systemPrompt+"\n") - } else { - util.WriteToOut(out, " \n") - } + handleParametersPrompt(out, conversation, mp) continue } if prompt == "/reset" || prompt == "/clear" { - conversation.Reset() - util.WriteToOut(out, "Reset chat history\n") + handleResetPrompt(out, conversation) continue } if strings.HasPrefix(prompt, "/set ") { - parts := strings.Split(prompt, " ") - if len(parts) == 3 { - name := parts[1] - value := parts[2] - - err := mp.SetParameterByName(name, value) - if err != nil { - util.WriteToOut(out, err.Error()+"\n") - continue - } - - util.WriteToOut(out, "Set "+name+" to "+value+"\n") - } else { - util.WriteToOut(out, "Invalid /set syntax. Usage: /set \n") - } + handleSetPrompt(out, prompt, mp) continue } if strings.HasPrefix(prompt, "/system-prompt ") { - conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") - util.WriteToOut(out, "Updated system prompt\n") + handleSystemPrompt(out, prompt, conversation) continue } if prompt == "/help" { - util.WriteToOut(out, "Commands:\n") - util.WriteToOut(out, " /bye, /exit, /quit - Exit the chat\n") - util.WriteToOut(out, " /parameters - Show current model parameters\n") - util.WriteToOut(out, " /reset, /clear - Reset chat context\n") - util.WriteToOut(out, " /set - Set a model parameter\n") - util.WriteToOut(out, " /system-prompt - Set the system prompt\n") - util.WriteToOut(out, " /help - Show this help message\n") + handleHelpPrompt(out) continue } - util.WriteToOut(out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") + handleUnrecognizedPrompt(out, prompt) continue } @@ -412,27 +340,9 @@ func NewRunCommand() *cobra.Command { sp.Stop() for _, choice := range completion.Choices { - // Streamed responses from the OpenAI API have their data in `.Delta`, while - // non-streamed responses use `.Message`, so let's support both - if choice.Delta != nil && choice.Delta.Content != nil { - content := choice.Delta.Content - _, err := messageBuilder.WriteString(*content) - if err != nil { - return err - } - util.WriteToOut(out, *content) - } else if choice.Message != nil && choice.Message.Content != nil { - content := choice.Message.Content - _, err := messageBuilder.WriteString(*content) - if err != nil { - return err - } - util.WriteToOut(out, *content) - } - - // Introduce a small delay in between response tokens to better simulate a conversation - if terminal.IsTerminalOutput() { - time.Sleep(10 * time.Millisecond) + err = handleCompletionChoice(choice, messageBuilder, out, terminal) + if err != nil { + return err } } } @@ -461,3 +371,140 @@ func NewRunCommand() *cobra.Command { return cmd } + +func getModelNameFromArgs(args []string, models []*azuremodels.ModelSummary) (string, error) { + modelName := "" + + switch { + case len(args) == 0: + // Need to prompt for a model + prompt := &survey.Select{ + Message: "Select a model:", + Options: []string{}, + } + + for _, model := range models { + if !ux.IsChatModel(model) { + continue + } + prompt.Options = append(prompt.Options, model.FriendlyName) + } + + err := survey.AskOne(prompt, &modelName, survey.WithPageSize(10)) + if err != nil { + return "", err + } + + case len(args) >= 1: + modelName = args[0] + } + + return validateModelName(modelName, models) +} + +func validateModelName(modelName string, models []*azuremodels.ModelSummary) (string, error) { + noMatchErrorMessage := "The specified model name is not found. Run 'gh models list' to see available models or 'gh models run' to select interactively." + + if modelName == "" { + return "", errors.New(noMatchErrorMessage) + } + + foundMatch := false + for _, model := range models { + if model.HasName(modelName) { + modelName = model.Name + foundMatch = true + break + } + } + + if !foundMatch { + return "", errors.New(noMatchErrorMessage) + } + + return modelName, nil +} + +func handleParametersPrompt(out io.Writer, conversation Conversation, mp ModelParameters) { + util.WriteToOut(out, "Current parameters:\n") + names := []string{"max-tokens", "temperature", "top-p"} + for _, name := range names { + util.WriteToOut(out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) + } + util.WriteToOut(out, "\n") + util.WriteToOut(out, "System Prompt:\n") + if conversation.systemPrompt != "" { + util.WriteToOut(out, " "+conversation.systemPrompt+"\n") + } else { + util.WriteToOut(out, " \n") + } +} + +func handleResetPrompt(out io.Writer, conversation Conversation) { + conversation.Reset() + util.WriteToOut(out, "Reset chat history\n") +} + +func handleSetPrompt(out io.Writer, prompt string, mp ModelParameters) { + parts := strings.Split(prompt, " ") + if len(parts) == 3 { + name := parts[1] + value := parts[2] + + err := mp.SetParameterByName(name, value) + if err != nil { + util.WriteToOut(out, err.Error()+"\n") + return + } + + util.WriteToOut(out, "Set "+name+" to "+value+"\n") + } else { + util.WriteToOut(out, "Invalid /set syntax. Usage: /set \n") + } +} + +func handleSystemPrompt(out io.Writer, prompt string, conversation Conversation) { + conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") + util.WriteToOut(out, "Updated system prompt\n") +} + +func handleHelpPrompt(out io.Writer) { + util.WriteToOut(out, "Commands:\n") + util.WriteToOut(out, " /bye, /exit, /quit - Exit the chat\n") + util.WriteToOut(out, " /parameters - Show current model parameters\n") + util.WriteToOut(out, " /reset, /clear - Reset chat context\n") + util.WriteToOut(out, " /set - Set a model parameter\n") + util.WriteToOut(out, " /system-prompt - Set the system prompt\n") + util.WriteToOut(out, " /help - Show this help message\n") +} + +func handleUnrecognizedPrompt(out io.Writer, prompt string) { + util.WriteToOut(out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") +} + +func handleCompletionChoice(choice azuremodels.ChatChoice, messageBuilder strings.Builder, out io.Writer, terminal term.Term) error { + // Streamed responses from the OpenAI API have their data in `.Delta`, while + // non-streamed responses use `.Message`, so let's support both + if choice.Delta != nil && choice.Delta.Content != nil { + content := choice.Delta.Content + _, err := messageBuilder.WriteString(*content) + if err != nil { + return err + } + util.WriteToOut(out, *content) + } else if choice.Message != nil && choice.Message.Content != nil { + content := choice.Message.Content + _, err := messageBuilder.WriteString(*content) + if err != nil { + return err + } + util.WriteToOut(out, *content) + } + + // Introduce a small delay in between response tokens to better simulate a conversation + if terminal.IsTerminalOutput() { + time.Sleep(10 * time.Millisecond) + } + + return nil +} diff --git a/internal/azuremodels/types.go b/internal/azuremodels/types.go index c290d7e..1070479 100644 --- a/internal/azuremodels/types.go +++ b/internal/azuremodels/types.go @@ -46,7 +46,8 @@ type chatChoiceDelta struct { Role *string `json:"role,omitempty"` } -type chatChoice struct { +// ChatChoice represents a choice in a chat completion. +type ChatChoice struct { Delta *chatChoiceDelta `json:"delta,omitempty"` FinishReason string `json:"finish_reason"` Index int32 `json:"index"` @@ -54,7 +55,7 @@ type chatChoice struct { } type chatCompletion struct { - Choices []chatChoice `json:"choices"` + Choices []ChatChoice `json:"choices"` } // ChatCompletionResponse represents a response to a chat completion request. From 5baeb6b36005ba7bae9a49e0349dcd0ea505a775 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 17:05:00 -0500 Subject: [PATCH 15/20] More linter fixes --- internal/azuremodels/client.go | 12 ++++++------ internal/azuremodels/types.go | 5 +++-- internal/ux/filtering.go | 1 + 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/internal/azuremodels/client.go b/internal/azuremodels/client.go index cd0b9d5..a4b60d3 100644 --- a/internal/azuremodels/client.go +++ b/internal/azuremodels/client.go @@ -54,7 +54,7 @@ func (c *Client) GetChatCompletionStream(ctx context.Context, req ChatCompletion body := bytes.NewReader(bodyBytes) - httpReq, err := http.NewRequestWithContext(ctx, "POST", prodInferenceURL, body) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodInferenceURL, body) if err != nil { return nil, err } @@ -77,15 +77,15 @@ func (c *Client) GetChatCompletionStream(ctx context.Context, req ChatCompletion if req.Stream { // Handle streamed response - chatCompletionResponse.Reader = sse.NewEventReader[chatCompletion](resp.Body) + chatCompletionResponse.Reader = sse.NewEventReader[ChatCompletion](resp.Body) } else { - var completion chatCompletion + var completion ChatCompletion if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { return nil, err } // Create a mock reader that returns the decoded completion - mockReader := sse.NewMockEventReader([]chatCompletion{completion}) + mockReader := sse.NewMockEventReader([]ChatCompletion{completion}) chatCompletionResponse.Reader = mockReader } @@ -95,7 +95,7 @@ func (c *Client) GetChatCompletionStream(ctx context.Context, req ChatCompletion // GetModelDetails returns the details of the specified model in a prticular registry. func (c *Client) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) - httpReq, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { return nil, err } @@ -184,7 +184,7 @@ func (c *Client) ListModels(ctx context.Context) ([]*ModelSummary, error) { } `)) - httpReq, err := http.NewRequestWithContext(ctx, "POST", prodModelsURL, body) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodModelsURL, body) if err != nil { return nil, err } diff --git a/internal/azuremodels/types.go b/internal/azuremodels/types.go index 1070479..98138fa 100644 --- a/internal/azuremodels/types.go +++ b/internal/azuremodels/types.go @@ -54,13 +54,14 @@ type ChatChoice struct { Message *chatChoiceMessage `json:"message,omitempty"` } -type chatCompletion struct { +// ChatCompletion represents a chat completion. +type ChatCompletion struct { Choices []ChatChoice `json:"choices"` } // ChatCompletionResponse represents a response to a chat completion request. type ChatCompletionResponse struct { - Reader sse.Reader[chatCompletion] + Reader sse.Reader[ChatCompletion] } type modelCatalogSearchResponse struct { diff --git a/internal/ux/filtering.go b/internal/ux/filtering.go index b645ac2..f456c85 100644 --- a/internal/ux/filtering.go +++ b/internal/ux/filtering.go @@ -1,3 +1,4 @@ +// Package ux provides utility functions around presentation and user experience. package ux import "github.com/github/gh-models/internal/azuremodels" From aeaa668432e3953dfd91f1a8273f2f322f29660a Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 17:05:31 -0500 Subject: [PATCH 16/20] Pull out runCommandHandler type Working toward fixing linter issues like these: Error: cmd/run/run.go:320:5: deferInLoop: Possible resource leak, 'defer' is called in the 'for' loop (gocritic) defer sp.Stop() ^ Error: cmd/run/run.go:327:5: deferInLoop: Possible resource leak, 'defer' is called in the 'for' loop (gocritic) defer resp.Reader.Close() ^ --- cmd/run/run.go | 159 +++++++++++++++++++++++++++++-------------------- 1 file changed, 96 insertions(+), 63 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index b4f3ab0..abc6cae 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -3,6 +3,7 @@ package run import ( "bufio" + "context" "errors" "fmt" "io" @@ -16,6 +17,7 @@ import ( "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/internal/ux" "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" @@ -194,27 +196,17 @@ func NewRunCommand() *cobra.Command { Short: "Run inference with the specified model", Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - terminal := term.FromEnv() - out := terminal.Out() - errOut := terminal.ErrOut() - - token, _ := auth.TokenForHost("github.com") - if token == "" { - util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + cmdHandler := newRunCommandHandler(cmd, args) + if cmdHandler == nil { return nil } - client := azuremodels.NewClient(token) - ctx := cmd.Context() - - models, err := client.ListModels(ctx) + models, err := cmdHandler.loadModels() if err != nil { return err } - ux.SortModels(models) - - modelName, err := getModelNameFromArgs(args, models) + modelName, err := cmdHandler.getModelNameFromArgs(models) if err != nil { return err } @@ -278,31 +270,31 @@ func NewRunCommand() *cobra.Command { } if prompt == "/parameters" { - handleParametersPrompt(out, conversation, mp) + cmdHandler.handleParametersPrompt(conversation, mp) continue } if prompt == "/reset" || prompt == "/clear" { - handleResetPrompt(out, conversation) + cmdHandler.handleResetPrompt(conversation) continue } if strings.HasPrefix(prompt, "/set ") { - handleSetPrompt(out, prompt, mp) + cmdHandler.handleSetPrompt(prompt, mp) continue } if strings.HasPrefix(prompt, "/system-prompt ") { - handleSystemPrompt(out, prompt, conversation) + cmdHandler.handleSystemPrompt(prompt, conversation) continue } if prompt == "/help" { - handleHelpPrompt(out) + cmdHandler.handleHelpPrompt() continue } - handleUnrecognizedPrompt(out, prompt) + cmdHandler.handleUnrecognizedPrompt(prompt) continue } @@ -315,21 +307,17 @@ func NewRunCommand() *cobra.Command { mp.UpdateRequest(&req) - sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(errOut)) + sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmdHandler.errOut)) sp.Start() defer sp.Stop() - resp, err := client.GetChatCompletionStream(ctx, req) - if err != nil { - return err - } - - defer resp.Reader.Close() + reader, err := cmdHandler.getChatCompletionStreamReader(req) + defer reader.Close() messageBuilder := strings.Builder{} for { - completion, err := resp.Reader.Read() + completion, err := reader.Read() if err != nil { if errors.Is(err, io.EOF) { break @@ -340,14 +328,14 @@ func NewRunCommand() *cobra.Command { sp.Stop() for _, choice := range completion.Choices { - err = handleCompletionChoice(choice, messageBuilder, out, terminal) + err = cmdHandler.handleCompletionChoice(choice, messageBuilder) if err != nil { return err } } } - util.WriteToOut(out, "\n") + util.WriteToOut(cmdHandler.out, "\n") _, err = messageBuilder.WriteString("\n") if err != nil { return err @@ -372,11 +360,48 @@ func NewRunCommand() *cobra.Command { return cmd } -func getModelNameFromArgs(args []string, models []*azuremodels.ModelSummary) (string, error) { +type runCommandHandler struct { + ctx context.Context + terminal term.Term + out io.Writer + errOut io.Writer + client *azuremodels.Client + args []string +} + +func newRunCommandHandler(cmd *cobra.Command, args []string) *runCommandHandler { + terminal := term.FromEnv() + out := terminal.Out() + token, _ := auth.TokenForHost("github.com") + if token == "" { + util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + return nil + } + return &runCommandHandler{ + ctx: cmd.Context(), + terminal: terminal, + out: out, + args: args, + errOut: terminal.ErrOut(), + client: azuremodels.NewClient(token), + } +} + +func (h *runCommandHandler) loadModels() ([]*azuremodels.ModelSummary, error) { + models, err := h.client.ListModels(h.ctx) + if err != nil { + return nil, err + } + + ux.SortModels(models) + return models, nil +} + +func (h *runCommandHandler) getModelNameFromArgs(models []*azuremodels.ModelSummary) (string, error) { modelName := "" switch { - case len(args) == 0: + case len(h.args) == 0: // Need to prompt for a model prompt := &survey.Select{ Message: "Select a model:", @@ -395,8 +420,8 @@ func getModelNameFromArgs(args []string, models []*azuremodels.ModelSummary) (st return "", err } - case len(args) >= 1: - modelName = args[0] + case len(h.args) >= 1: + modelName = h.args[0] } return validateModelName(modelName, models) @@ -425,27 +450,35 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st return modelName, nil } -func handleParametersPrompt(out io.Writer, conversation Conversation, mp ModelParameters) { - util.WriteToOut(out, "Current parameters:\n") +func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions) (sse.Reader[azuremodels.ChatCompletion], error) { + resp, err := h.client.GetChatCompletionStream(h.ctx, req) + if err != nil { + return nil, err + } + return resp.Reader, nil +} + +func (h *runCommandHandler) handleParametersPrompt(conversation Conversation, mp ModelParameters) { + util.WriteToOut(h.out, "Current parameters:\n") names := []string{"max-tokens", "temperature", "top-p"} for _, name := range names { - util.WriteToOut(out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) + util.WriteToOut(h.out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) } - util.WriteToOut(out, "\n") - util.WriteToOut(out, "System Prompt:\n") + util.WriteToOut(h.out, "\n") + util.WriteToOut(h.out, "System Prompt:\n") if conversation.systemPrompt != "" { - util.WriteToOut(out, " "+conversation.systemPrompt+"\n") + util.WriteToOut(h.out, " "+conversation.systemPrompt+"\n") } else { - util.WriteToOut(out, " \n") + util.WriteToOut(h.out, " \n") } } -func handleResetPrompt(out io.Writer, conversation Conversation) { +func (h *runCommandHandler) handleResetPrompt(conversation Conversation) { conversation.Reset() - util.WriteToOut(out, "Reset chat history\n") + util.WriteToOut(h.out, "Reset chat history\n") } -func handleSetPrompt(out io.Writer, prompt string, mp ModelParameters) { +func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { parts := strings.Split(prompt, " ") if len(parts) == 3 { name := parts[1] @@ -453,36 +486,36 @@ func handleSetPrompt(out io.Writer, prompt string, mp ModelParameters) { err := mp.SetParameterByName(name, value) if err != nil { - util.WriteToOut(out, err.Error()+"\n") + util.WriteToOut(h.out, err.Error()+"\n") return } - util.WriteToOut(out, "Set "+name+" to "+value+"\n") + util.WriteToOut(h.out, "Set "+name+" to "+value+"\n") } else { - util.WriteToOut(out, "Invalid /set syntax. Usage: /set \n") + util.WriteToOut(h.out, "Invalid /set syntax. Usage: /set \n") } } -func handleSystemPrompt(out io.Writer, prompt string, conversation Conversation) { +func (h *runCommandHandler) handleSystemPrompt(prompt string, conversation Conversation) { conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") - util.WriteToOut(out, "Updated system prompt\n") + util.WriteToOut(h.out, "Updated system prompt\n") } -func handleHelpPrompt(out io.Writer) { - util.WriteToOut(out, "Commands:\n") - util.WriteToOut(out, " /bye, /exit, /quit - Exit the chat\n") - util.WriteToOut(out, " /parameters - Show current model parameters\n") - util.WriteToOut(out, " /reset, /clear - Reset chat context\n") - util.WriteToOut(out, " /set - Set a model parameter\n") - util.WriteToOut(out, " /system-prompt - Set the system prompt\n") - util.WriteToOut(out, " /help - Show this help message\n") +func (h *runCommandHandler) handleHelpPrompt() { + util.WriteToOut(h.out, "Commands:\n") + util.WriteToOut(h.out, " /bye, /exit, /quit - Exit the chat\n") + util.WriteToOut(h.out, " /parameters - Show current model parameters\n") + util.WriteToOut(h.out, " /reset, /clear - Reset chat context\n") + util.WriteToOut(h.out, " /set - Set a model parameter\n") + util.WriteToOut(h.out, " /system-prompt - Set the system prompt\n") + util.WriteToOut(h.out, " /help - Show this help message\n") } -func handleUnrecognizedPrompt(out io.Writer, prompt string) { - util.WriteToOut(out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") +func (h *runCommandHandler) handleUnrecognizedPrompt(prompt string) { + util.WriteToOut(h.out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") } -func handleCompletionChoice(choice azuremodels.ChatChoice, messageBuilder strings.Builder, out io.Writer, terminal term.Term) error { +func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice, messageBuilder strings.Builder) error { // Streamed responses from the OpenAI API have their data in `.Delta`, while // non-streamed responses use `.Message`, so let's support both if choice.Delta != nil && choice.Delta.Content != nil { @@ -491,18 +524,18 @@ func handleCompletionChoice(choice azuremodels.ChatChoice, messageBuilder string if err != nil { return err } - util.WriteToOut(out, *content) + util.WriteToOut(h.out, *content) } else if choice.Message != nil && choice.Message.Content != nil { content := choice.Message.Content _, err := messageBuilder.WriteString(*content) if err != nil { return err } - util.WriteToOut(out, *content) + util.WriteToOut(h.out, *content) } // Introduce a small delay in between response tokens to better simulate a conversation - if terminal.IsTerminalOutput() { + if h.terminal.IsTerminalOutput() { time.Sleep(10 * time.Millisecond) } From 166451c0fe2e7118996d5d16f541bc63955a7e6f Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 17:09:44 -0500 Subject: [PATCH 17/20] Handle some unused writes and vars --- cmd/run/run.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index abc6cae..9de9e2c 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -285,7 +285,7 @@ func NewRunCommand() *cobra.Command { } if strings.HasPrefix(prompt, "/system-prompt ") { - cmdHandler.handleSystemPrompt(prompt, conversation) + conversation = cmdHandler.handleSystemPrompt(prompt, conversation) continue } @@ -312,6 +312,9 @@ func NewRunCommand() *cobra.Command { defer sp.Stop() reader, err := cmdHandler.getChatCompletionStreamReader(req) + if err != nil { + return err + } defer reader.Close() messageBuilder := strings.Builder{} @@ -496,9 +499,10 @@ func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { } } -func (h *runCommandHandler) handleSystemPrompt(prompt string, conversation Conversation) { +func (h *runCommandHandler) handleSystemPrompt(prompt string, conversation Conversation) Conversation { conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") util.WriteToOut(h.out, "Updated system prompt\n") + return conversation } func (h *runCommandHandler) handleHelpPrompt() { From 24d70db1eafc606b4dedd3fc72f59fd5acf1afc4 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 17:14:07 -0500 Subject: [PATCH 18/20] Disable the last couple linters for now This PR is big enough already! --- cmd/run/run.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/run/run.go b/cmd/run/run.go index 9de9e2c..8dd6719 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -309,12 +309,14 @@ func NewRunCommand() *cobra.Command { sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmdHandler.errOut)) sp.Start() + //nolint:deferInLoop // TODO defer sp.Stop() reader, err := cmdHandler.getChatCompletionStreamReader(req) if err != nil { return err } + //nolint:deferInLoop // TODO defer reader.Close() messageBuilder := strings.Builder{} From 4b77a2307bcd096a5c0aa080e74a72392b22ab42 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 17:17:17 -0500 Subject: [PATCH 19/20] Disable the right linter --- cmd/run/run.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index 8dd6719..04101c6 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -309,14 +309,14 @@ func NewRunCommand() *cobra.Command { sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmdHandler.errOut)) sp.Start() - //nolint:deferInLoop // TODO + //nolint:gocritic // TODO defer sp.Stop() reader, err := cmdHandler.getChatCompletionStreamReader(req) if err != nil { return err } - //nolint:deferInLoop // TODO + //nolint:gocritic // TODO defer reader.Close() messageBuilder := strings.Builder{} From a0d2e2d9714796b8a8edfa821deb022916472905 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Thu, 10 Oct 2024 17:21:38 -0500 Subject: [PATCH 20/20] revive linter also dislikes it --- cmd/run/run.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index 04101c6..bc12d03 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -309,14 +309,14 @@ func NewRunCommand() *cobra.Command { sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmdHandler.errOut)) sp.Start() - //nolint:gocritic // TODO + //nolint:gocritic,revive // TODO defer sp.Stop() reader, err := cmdHandler.getChatCompletionStreamReader(req) if err != nil { return err } - //nolint:gocritic // TODO + //nolint:gocritic,revive // TODO defer reader.Close() messageBuilder := strings.Builder{}