diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..cc5ebdc --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,36 @@ +name: "go-linter" + +on: + pull_request: + merge_group: + workflow_dispatch: + +permissions: + contents: read + +jobs: + lint: + strategy: + fail-fast: false + runs-on: ubuntu-latest-xl + env: + GOPROXY: https://goproxy.githubapp.com/mod,https://proxy.golang.org/,direct + GOPRIVATE: "" + GONOPROXY: "" + GONOSUMDB: github.com/github/* + steps: + - uses: actions/setup-go@v5 + with: + go-version: ${{ vars.GOVERSION }} + check-latest: true + - uses: actions/checkout@v4 + - name: Configure Go private module access + run: | + echo "machine goproxy.githubapp.com login nobody password ${{ secrets.GOPROXY_TOKEN }}" >> $HOME/.netrc + - name: Lint + # This also does checkout, setup-go, and proxy setup. + uses: github/go-linter@v1.2.1 + with: + strict: true + go-version: ${{ vars.GOVERSION }} + goproxy-token: ${{ secrets.GOPROXY_TOKEN }} diff --git a/README.md b/README.md index b8dc4e5..138369c 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Use the GitHub Models service from the CLI! ### Prerequisites -The extension requires the `gh` CLI to be installed and in the PATH. The extension also requires the user have authenticated via `gh auth`. +The extension requires the [`gh` CLI](https://cli.github.com/) to be installed and in the `PATH`. The extension also requires the user to have authenticated via `gh auth`. ### Installing @@ -73,4 +73,4 @@ git tag v0.0.x main git push origin tag v0.0.x ``` -This will trigger the `release` action that runs the actual production build. \ No newline at end of file +This will trigger the `release` action that runs the actual production build. diff --git a/cmd/list/list.go b/cmd/list/list.go index 8ea9e11..6ab3088 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -1,14 +1,15 @@ +// Package list provides a gh command to list available models. package list import ( "fmt" - "io" "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/tableprinter" "github.com/cli/go-gh/v2/pkg/term" - "github.com/github/gh-models/internal/azure_models" + "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/pkg/util" "github.com/mgutz/ansi" "github.com/spf13/cobra" ) @@ -17,6 +18,7 @@ var ( lightGrayUnderline = ansi.ColorFunc("white+du") ) +// NewListCommand returns a new command to list available GitHub models. func NewListCommand() *cobra.Command { cmd := &cobra.Command{ Use: "list", @@ -28,28 +30,29 @@ func NewListCommand() *cobra.Command { token, _ := auth.TokenForHost("github.com") if token == "" { - io.WriteString(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") return nil } - client := azure_models.NewClient(token) + client := azuremodels.NewClient(token) + ctx := cmd.Context() - models, err := client.ListModels() + models, err := client.ListModels(ctx) if err != nil { return err } // For now, filter to just chat models. // Once other tasks are supported (like embeddings), update the list to show all models, with the task as a column.
- models = ux.FilterToChatModels(models) + models = filterToChatModels(models) ux.SortModels(models) isTTY := terminal.IsTerminalOutput() if isTTY { - io.WriteString(out, "\n") - io.WriteString(out, fmt.Sprintf("Showing %d available chat models\n", len(models))) - io.WriteString(out, "\n") + util.WriteToOut(out, "\n") + util.WriteToOut(out, fmt.Sprintf("Showing %d available chat models\n", len(models))) + util.WriteToOut(out, "\n") } width, _, _ := terminal.Size() @@ -75,3 +78,13 @@ func NewListCommand() *cobra.Command { return cmd } + +func filterToChatModels(models []*azuremodels.ModelSummary) []*azuremodels.ModelSummary { + var chatModels []*azuremodels.ModelSummary + for _, model := range models { + if ux.IsChatModel(model) { + chatModels = append(chatModels, model) + } + } + return chatModels +} diff --git a/cmd/root.go b/cmd/root.go index 298ab39..3e8caf5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,13 +1,16 @@ +// Package cmd represents the base command when called without any subcommands. package cmd import ( + "strings" + "github.com/github/gh-models/cmd/list" "github.com/github/gh-models/cmd/run" "github.com/github/gh-models/cmd/view" "github.com/spf13/cobra" - "strings" ) +// NewRootCommand returns a new root command for the gh-models extension. func NewRootCommand() *cobra.Command { cmd := &cobra.Command{ Use: "models", diff --git a/cmd/run/run.go b/cmd/run/run.go index 064b914..bc12d03 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -1,7 +1,9 @@ +// Package run provides a gh command to run a GitHub model. package run import ( "bufio" + "context" "errors" "fmt" "io" @@ -14,18 +16,22 @@ import ( "github.com/briandowns/spinner" "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/term" - "github.com/github/gh-models/internal/azure_models" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" "github.com/spf13/pflag" ) +// ModelParameters represents the parameters that can be set for a model run. type ModelParameters struct { maxTokens *int temperature *float64 topP *float64 } +// FormatParameter returns a string representation of the parameter value. func (mp *ModelParameters) FormatParameter(name string) string { switch name { case "max-tokens": @@ -47,6 +53,7 @@ func (mp *ModelParameters) FormatParameter(name string) string { return "" } +// PopulateFromFlags populates the model parameters from the given flags. 
func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { maxTokensString, err := flags.GetString("max-tokens") if err != nil { @@ -57,7 +64,7 @@ func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { if err != nil { return err } - mp.maxTokens = azure_models.Ptr(maxTokens) + mp.maxTokens = util.Ptr(maxTokens) } temperatureString, err := flags.GetString("temperature") @@ -69,7 +76,7 @@ func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { if err != nil { return err } - mp.temperature = azure_models.Ptr(temperature) + mp.temperature = util.Ptr(temperature) } topPString, err := flags.GetString("top-p") @@ -81,34 +88,35 @@ func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { if err != nil { return err } - mp.topP = azure_models.Ptr(topP) + mp.topP = util.Ptr(topP) } return nil } -func (mp *ModelParameters) SetParameterByName(name string, value string) error { +// SetParameterByName sets the parameter with the given name to the given value. +func (mp *ModelParameters) SetParameterByName(name, value string) error { switch name { case "max-tokens": maxTokens, err := strconv.Atoi(value) if err != nil { return err } - mp.maxTokens = azure_models.Ptr(maxTokens) + mp.maxTokens = util.Ptr(maxTokens) case "temperature": temperature, err := strconv.ParseFloat(value, 64) if err != nil { return err } - mp.temperature = azure_models.Ptr(temperature) + mp.temperature = util.Ptr(temperature) case "top-p": topP, err := strconv.ParseFloat(value, 64) if err != nil { return err } - mp.topP = azure_models.Ptr(topP) + mp.topP = util.Ptr(topP) default: return errors.New("unknown parameter '" + name + "'. Supported parameters: max-tokens, temperature, top-p") @@ -117,37 +125,41 @@ func (mp *ModelParameters) SetParameterByName(name string, value string) error { return nil } -func (mp *ModelParameters) UpdateRequest(req *azure_models.ChatCompletionOptions) { +// UpdateRequest updates the given request with the model parameters. +func (mp *ModelParameters) UpdateRequest(req *azuremodels.ChatCompletionOptions) { req.MaxTokens = mp.maxTokens req.Temperature = mp.temperature req.TopP = mp.topP } +// Conversation represents a conversation between the user and the model. type Conversation struct { - messages []azure_models.ChatMessage + messages []azuremodels.ChatMessage systemPrompt string } -func (c *Conversation) AddMessage(role azure_models.ChatMessageRole, content string) { - c.messages = append(c.messages, azure_models.ChatMessage{ - Content: azure_models.Ptr(content), +// AddMessage adds a message to the conversation. +func (c *Conversation) AddMessage(role azuremodels.ChatMessageRole, content string) { + c.messages = append(c.messages, azuremodels.ChatMessage{ + Content: util.Ptr(content), Role: role, }) } -func (c *Conversation) GetMessages() []azure_models.ChatMessage { +// GetMessages returns the messages in the conversation. 
+func (c *Conversation) GetMessages() []azuremodels.ChatMessage { length := len(c.messages) if c.systemPrompt != "" { length++ } - messages := make([]azure_models.ChatMessage, length) + messages := make([]azuremodels.ChatMessage, length) startIndex := 0 if c.systemPrompt != "" { - messages[0] = azure_models.ChatMessage{ - Content: azure_models.Ptr(c.systemPrompt), - Role: azure_models.ChatMessageRoleSystem, + messages[0] = azuremodels.ChatMessage{ + Content: util.Ptr(c.systemPrompt), + Role: azuremodels.ChatMessageRoleSystem, } startIndex++ } @@ -159,6 +171,7 @@ func (c *Conversation) GetMessages() []azure_models.ChatMessage { return messages } +// Reset removes messages from the conversation. func (c *Conversation) Reset() { c.messages = nil } @@ -176,73 +189,26 @@ func isPipe(r io.Reader) bool { return false } +// NewRunCommand returns a new gh command for running a model. func NewRunCommand() *cobra.Command { cmd := &cobra.Command{ Use: "run [model] [prompt]", Short: "Run inference with the specified model", Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - terminal := term.FromEnv() - out := terminal.Out() - errOut := terminal.ErrOut() - - token, _ := auth.TokenForHost("github.com") - if token == "" { - io.WriteString(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + cmdHandler := newRunCommandHandler(cmd, args) + if cmdHandler == nil { return nil } - client := azure_models.NewClient(token) - - models, err := client.ListModels() + models, err := cmdHandler.loadModels() if err != nil { return err } - ux.SortModels(models) - - modelName := "" - switch { - case len(args) == 0: - // Need to prompt for a model - prompt := &survey.Select{ - Message: "Select a model:", - Options: []string{}, - } - - for _, model := range models { - if !ux.IsChatModel(model) { - continue - } - prompt.Options = append(prompt.Options, model.FriendlyName) - } - - err = survey.AskOne(prompt, &modelName, survey.WithPageSize(10)) - if err != nil { - return err - } - - case len(args) >= 1: - modelName = args[0] - } - - noMatchErrorMessage := "The specified model name is not found. Run 'gh models list' to see available models or 'gh models run' to select interactively." 
- - if modelName == "" { - return errors.New(noMatchErrorMessage) - } - - foundMatch := false - for _, model := range models { - if model.HasName(modelName) { - modelName = model.Name - foundMatch = true - break - } - } - - if !foundMatch { - return errors.New(noMatchErrorMessage) + modelName, err := cmdHandler.getModelNameFromArgs(models) + if err != nil { + return err } initialPrompt := "" @@ -304,91 +270,59 @@ func NewRunCommand() *cobra.Command { } if prompt == "/parameters" { - io.WriteString(out, "Current parameters:\n") - names := []string{"max-tokens", "temperature", "top-p"} - for _, name := range names { - io.WriteString(out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) - } - io.WriteString(out, "\n") - io.WriteString(out, "System Prompt:\n") - if conversation.systemPrompt != "" { - io.WriteString(out, " "+conversation.systemPrompt+"\n") - } else { - io.WriteString(out, " <none>\n") - } + cmdHandler.handleParametersPrompt(conversation, mp) continue } if prompt == "/reset" || prompt == "/clear" { - conversation.Reset() - io.WriteString(out, "Reset chat history\n") + cmdHandler.handleResetPrompt(&conversation) continue } if strings.HasPrefix(prompt, "/set ") { - parts := strings.Split(prompt, " ") - if len(parts) == 3 { - name := parts[1] - value := parts[2] - - err := mp.SetParameterByName(name, value) - if err != nil { - io.WriteString(out, err.Error()+"\n") - continue - } - - io.WriteString(out, "Set "+name+" to "+value+"\n") - } else { - io.WriteString(out, "Invalid /set syntax. Usage: /set <name> <value>\n") - } + cmdHandler.handleSetPrompt(prompt, &mp) continue } if strings.HasPrefix(prompt, "/system-prompt ") { - conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") - io.WriteString(out, "Updated system prompt\n") + conversation = cmdHandler.handleSystemPrompt(prompt, conversation) continue } if prompt == "/help" { - io.WriteString(out, "Commands:\n") - io.WriteString(out, " /bye, /exit, /quit - Exit the chat\n") - io.WriteString(out, " /parameters - Show current model parameters\n") - io.WriteString(out, " /reset, /clear - Reset chat context\n") - io.WriteString(out, " /set <name> <value> - Set a model parameter\n") - io.WriteString(out, " /system-prompt <prompt> - Set the system prompt\n") - io.WriteString(out, " /help - Show this help message\n") + cmdHandler.handleHelpPrompt() continue } - io.WriteString(out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") + cmdHandler.handleUnrecognizedPrompt(prompt) continue } - conversation.AddMessage(azure_models.ChatMessageRoleUser, prompt) + conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt) - req := azure_models.ChatCompletionOptions{ + req := azuremodels.ChatCompletionOptions{ Messages: conversation.GetMessages(), Model: modelName, } mp.UpdateRequest(&req) - sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(errOut)) + sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmdHandler.errOut)) sp.Start() + //nolint:gocritic,revive // TODO defer sp.Stop() - resp, err := client.GetChatCompletionStream(req) + reader, err := cmdHandler.getChatCompletionStreamReader(req) if err != nil { return err } - - defer resp.Reader.Close() + //nolint:gocritic,revive // TODO + defer reader.Close() messageBuilder := strings.Builder{} for { - completion, err := resp.Reader.Read() + completion, err := reader.Read() if err != nil { if errors.Is(err, io.EOF) { break @@ -399,29 +333,20 @@ func NewRunCommand() *cobra.Command { sp.Stop() for _, choice := range completion.Choices { - // Streamed responses from the OpenAI API have their data in `.Delta`, while - // non-streamed responses use `.Message`, so let's support both - if choice.Delta != nil && choice.Delta.Content != nil { - content := choice.Delta.Content - messageBuilder.WriteString(*content) - io.WriteString(out, *content) - } else if choice.Message != nil && choice.Message.Content != nil { - content := choice.Message.Content - messageBuilder.WriteString(*content) - io.WriteString(out, *content) - } - - // Introduce a small delay in between response tokens to better simulate a conversation - if terminal.IsTerminalOutput() { - time.Sleep(10 * time.Millisecond) + err = cmdHandler.handleCompletionChoice(choice, &messageBuilder) + if err != nil { + return err } } } - io.WriteString(out, "\n") - messageBuilder.WriteString("\n") + util.WriteToOut(cmdHandler.out, "\n") + _, err = messageBuilder.WriteString("\n") + if err != nil { + return err + } - conversation.AddMessage(azure_models.ChatMessageRoleAssistant, messageBuilder.String()) + conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, messageBuilder.String()) if singleShot { break @@ -439,3 +364,186 @@ func NewRunCommand() *cobra.Command { return cmd } + +type runCommandHandler struct { + ctx context.Context + terminal term.Term + out io.Writer + errOut io.Writer + client *azuremodels.Client + args []string +} + +func newRunCommandHandler(cmd *cobra.Command, args []string) *runCommandHandler { + terminal := term.FromEnv() + out := terminal.Out() + token, _ := auth.TokenForHost("github.com") + if token == "" { + util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + return nil + } + return &runCommandHandler{ + ctx: cmd.Context(), + terminal: terminal, + out: out, + args: args, + errOut: terminal.ErrOut(), + client: azuremodels.NewClient(token), + } +} + +func (h *runCommandHandler) loadModels() ([]*azuremodels.ModelSummary, error) { + models, err := h.client.ListModels(h.ctx) + if err != nil { + return nil, err + } + + ux.SortModels(models) + return models, nil +} + +func (h *runCommandHandler) getModelNameFromArgs(models []*azuremodels.ModelSummary) (string, error) { + modelName := "" + + switch { + case len(h.args) == 0: + // Need to prompt for a model + prompt := &survey.Select{ + Message: "Select a model:", + Options: []string{}, + } + + for _, model := range models { + if !ux.IsChatModel(model) { + continue + } + prompt.Options = append(prompt.Options, model.FriendlyName) + } + + err := survey.AskOne(prompt, &modelName, survey.WithPageSize(10)) + if err != nil { + return "", err + } + + case len(h.args) >= 1: + modelName = h.args[0] + } + + return validateModelName(modelName, models) +} + +func validateModelName(modelName string, models []*azuremodels.ModelSummary) (string, error) { + noMatchErrorMessage := "The specified model name is not found. Run 'gh models list' to see available models or 'gh models run' to select interactively." + + if modelName == "" { + return "", errors.New(noMatchErrorMessage) + } + + foundMatch := false + for _, model := range models { + if model.HasName(modelName) { + modelName = model.Name + foundMatch = true + break + } + } + + if !foundMatch { + return "", errors.New(noMatchErrorMessage) + } + + return modelName, nil +} + +func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions) (sse.Reader[azuremodels.ChatCompletion], error) { + resp, err := h.client.GetChatCompletionStream(h.ctx, req) + if err != nil { + return nil, err + } + return resp.Reader, nil +} + +func (h *runCommandHandler) handleParametersPrompt(conversation Conversation, mp ModelParameters) { + util.WriteToOut(h.out, "Current parameters:\n") + names := []string{"max-tokens", "temperature", "top-p"} + for _, name := range names { + util.WriteToOut(h.out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) + } + util.WriteToOut(h.out, "\n") + util.WriteToOut(h.out, "System Prompt:\n") + if conversation.systemPrompt != "" { + util.WriteToOut(h.out, " "+conversation.systemPrompt+"\n") + } else { + util.WriteToOut(h.out, " <none>\n") + } +} + +func (h *runCommandHandler) handleResetPrompt(conversation *Conversation) { + conversation.Reset() + util.WriteToOut(h.out, "Reset chat history\n") +} + +func (h *runCommandHandler) handleSetPrompt(prompt string, mp *ModelParameters) { + parts := strings.Split(prompt, " ") + if len(parts) == 3 { + name := parts[1] + value := parts[2] + + err := mp.SetParameterByName(name, value) + if err != nil { + util.WriteToOut(h.out, err.Error()+"\n") + return + } + + util.WriteToOut(h.out, "Set "+name+" to "+value+"\n") + } else { + util.WriteToOut(h.out, "Invalid /set syntax. Usage: /set <name> <value>\n") + } +} + +func (h *runCommandHandler) handleSystemPrompt(prompt string, conversation Conversation) Conversation { + conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") + util.WriteToOut(h.out, "Updated system prompt\n") + return conversation +} + +func (h *runCommandHandler) handleHelpPrompt() { + util.WriteToOut(h.out, "Commands:\n") + util.WriteToOut(h.out, " /bye, /exit, /quit - Exit the chat\n") + util.WriteToOut(h.out, " /parameters - Show current model parameters\n") + util.WriteToOut(h.out, " /reset, /clear - Reset chat context\n") + util.WriteToOut(h.out, " /set <name> <value> - Set a model parameter\n") + util.WriteToOut(h.out, " /system-prompt <prompt> - Set the system prompt\n") + util.WriteToOut(h.out, " /help - Show this help message\n") +} + +func (h *runCommandHandler) handleUnrecognizedPrompt(prompt string) { + util.WriteToOut(h.out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") +} + +func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice, messageBuilder *strings.Builder) error { + // Streamed responses from the OpenAI API have their data in `.Delta`, while + // non-streamed responses use `.Message`, so let's support both + if choice.Delta != nil && choice.Delta.Content != nil { + content := choice.Delta.Content + _, err := messageBuilder.WriteString(*content) + if err != nil { + return err + } + util.WriteToOut(h.out, *content) + } else if choice.Message != nil && choice.Message.Content != nil { + content := choice.Message.Content + _, err := messageBuilder.WriteString(*content) + if err != nil { + return err + } + util.WriteToOut(h.out, *content) + } + + // Introduce a small delay in between response tokens to better simulate a conversation + if h.terminal.IsTerminalOutput() { + time.Sleep(10 * time.Millisecond) + } + + return nil +} diff --git a/cmd/view/model_printer.go b/cmd/view/model_printer.go index d3f20b7..63790f3 100644 --- a/cmd/view/model_printer.go +++ b/cmd/view/model_printer.go @@ -6,7 +6,7 @@ import ( "github.com/cli/cli/v2/pkg/markdown" "github.com/cli/go-gh/v2/pkg/tableprinter" "github.com/cli/go-gh/v2/pkg/term" - "github.com/github/gh-models/internal/azure_models" + "github.com/github/gh-models/internal/azuremodels" "github.com/mgutz/ansi" ) @@ -15,13 +15,13 @@ var ( ) type modelPrinter struct { - modelSummary *azure_models.ModelSummary - modelDetails *azure_models.ModelDetails + modelSummary *azuremodels.ModelSummary + modelDetails *azuremodels.ModelDetails printer tableprinter.TablePrinter terminalWidth int } -func newModelPrinter(summary *azure_models.ModelSummary, details *azure_models.ModelDetails, terminal term.Term) modelPrinter { +func newModelPrinter(summary *azuremodels.ModelSummary, details *azuremodels.ModelDetails, terminal term.Term) modelPrinter { width, _, _ := terminal.Size() printer := tableprinter.New(terminal.Out(), terminal.IsTerminalOutput(), width) return modelPrinter{modelSummary: summary, modelDetails: details, printer: printer, terminalWidth: width} @@ -59,7 +59,7 @@ func (p *modelPrinter) render() error { return nil } -func (p *modelPrinter) printLabelledLine(label string, value string) { +func (p *modelPrinter) printLabelledLine(label, value string) { if value == "" { return } @@ -76,7 +76,7 @@ func (p *modelPrinter) printLabelledMultiLineList(label string, values []string) p.printMultipleLinesWithLabel(label, strings.Join(values, ", ")) } -func (p *modelPrinter) printMultipleLinesWithLabel(label string, value string) { +func (p *modelPrinter)
printMultipleLinesWithLabel(label, value string) { if value == "" { return } diff --git a/cmd/view/view.go b/cmd/view/view.go index f36a051..777281d 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -1,17 +1,19 @@ +// Package view provides a `gh models view` command to view details about a model. package view import ( "fmt" - "io" "github.com/AlecAivazis/survey/v2" "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/term" - "github.com/github/gh-models/internal/azure_models" + "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" ) +// NewViewCommand returns a new command to view details about a model. func NewViewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "view [model]", @@ -22,13 +24,14 @@ func NewViewCommand() *cobra.Command { token, _ := auth.TokenForHost("github.com") if token == "" { - io.WriteString(terminal.Out(), "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + util.WriteToOut(terminal.Out(), "No GitHub token found. Please run 'gh auth login' to authenticate.\n") return nil } - client := azure_models.NewClient(token) + client := azuremodels.NewClient(token) + ctx := cmd.Context() - models, err := client.ListModels() + models, err := client.ListModels(ctx) if err != nil { return err } @@ -65,7 +68,7 @@ func NewViewCommand() *cobra.Command { return err } - modelDetails, err := client.GetModelDetails(modelSummary.RegistryName, modelSummary.Name, modelSummary.Version) + modelDetails, err := client.GetModelDetails(ctx, modelSummary.RegistryName, modelSummary.Name, modelSummary.Version) if err != nil { return err } @@ -84,7 +87,7 @@ func NewViewCommand() *cobra.Command { } // getModelByName returns the model with the specified name, or an error if no such model exists within the given list. -func getModelByName(modelName string, models []*azure_models.ModelSummary) (*azure_models.ModelSummary, error) { +func getModelByName(modelName string, models []*azuremodels.ModelSummary) (*azuremodels.ModelSummary, error) { for _, model := range models { if model.HasName(modelName) { return model, nil diff --git a/internal/azure_models/client.go b/internal/azuremodels/client.go similarity index 77% rename from internal/azure_models/client.go rename to internal/azuremodels/client.go index 996c270..a4b60d3 100644 --- a/internal/azure_models/client.go +++ b/internal/azuremodels/client.go @@ -1,7 +1,9 @@ -package azure_models +// Package azuremodels provides a client for interacting with the Azure models API. +package azuremodels import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -15,6 +17,7 @@ import ( "golang.org/x/text/language/display" ) +// Client provides a client for interacting with the Azure models API. type Client struct { client *http.Client token string @@ -26,6 +29,7 @@ const ( prodModelsURL = azureAiStudioURL + "/asset-gallery/v1.0/models" ) +// NewClient returns a new client using the given auth token. func NewClient(authToken string) *Client { httpClient, _ := api.DefaultHTTPClient() return &Client{ @@ -34,7 +38,8 @@ func NewClient(authToken string) *Client { } } -func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatCompletionResponse, error) { +// GetChatCompletionStream returns a stream of chat completions for the given request. 
+func (c *Client) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) { // Check if the model name is `o1-mini` or `o1-preview` if req.Model == "o1-mini" || req.Model == "o1-preview" { req.Stream = false @@ -49,7 +54,7 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple body := bytes.NewReader(bodyBytes) - httpReq, err := http.NewRequest("POST", prodInferenceURL, body) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodInferenceURL, body) if err != nil { return nil, err } @@ -87,9 +92,10 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple return &chatCompletionResponse, nil } -func (c *Client) GetModelDetails(registry string, modelName string, version string) (*ModelDetails, error) { +// GetModelDetails returns the details of the specified model in a particular registry. +func (c *Client) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) - httpReq, err := http.NewRequest("GET", url, nil) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { return nil, err } @@ -101,6 +107,8 @@ func (c *Client) GetModelDetails(registry string, modelName string, version stri return nil, err } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return nil, c.handleHTTPError(resp) } @@ -162,7 +170,8 @@ func lowercaseStrings(input []string) []string { return output } -func (c *Client) ListModels() ([]*ModelSummary, error) { +// ListModels returns a list of available models. +func (c *Client) ListModels(ctx context.Context) ([]*ModelSummary, error) { body := bytes.NewReader([]byte(` { "filters": [ @@ -175,7 +184,7 @@ } `)) - httpReq, err := http.NewRequest("POST", prodModelsURL, body) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodModelsURL, body) if err != nil { return nil, err } @@ -187,6 +196,8 @@ func (c *Client) ListModels() ([]*ModelSummary, error) { return nil, err } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return nil, c.handleHTTPError(resp) } @@ -223,25 +234,45 @@ func (c *Client) ListModels() ([]*ModelSummary, error) { } func (c *Client) handleHTTPError(resp *http.Response) error { sb := strings.Builder{} + var err error switch resp.StatusCode { case http.StatusUnauthorized: - sb.WriteString("unauthorized") + _, err = sb.WriteString("unauthorized") + if err != nil { + return err + } case http.StatusBadRequest: - sb.WriteString("bad request") + _, err = sb.WriteString("bad request") + if err != nil { + return err + } default: - sb.WriteString("unexpected response from the server: " + resp.Status) + _, err = sb.WriteString("unexpected response from the server: " + resp.Status) + if err != nil { + return err + } } body, _ := io.ReadAll(resp.Body) if len(body) > 0 { - sb.WriteString("\n") - sb.Write(body) - sb.WriteString("\n") + _, err = sb.WriteString("\n") + if err != nil { + return err + } + + _, err = sb.Write(body) + if err != nil { + return err + } + + _, err = sb.WriteString("\n") + if err != nil { + return err + } } return errors.New(sb.String()) diff --git a/internal/azure_models/types.go b/internal/azuremodels/types.go similarity index 80% rename from internal/azure_models/types.go rename to internal/azuremodels/types.go index
de76846..98138fa 100644 --- a/internal/azure_models/types.go +++ b/internal/azuremodels/types.go @@ -1,4 +1,4 @@ -package azure_models +package azuremodels import ( "encoding/json" @@ -8,19 +8,25 @@ import ( "github.com/github/gh-models/internal/sse" ) +// ChatMessageRole represents the role of a chat message. type ChatMessageRole string const ( + // ChatMessageRoleAssistant represents a message from the model. ChatMessageRoleAssistant ChatMessageRole = "assistant" - ChatMessageRoleSystem ChatMessageRole = "system" - ChatMessageRoleUser ChatMessageRole = "user" + // ChatMessageRoleSystem represents a system message. + ChatMessageRoleSystem ChatMessageRole = "system" + // ChatMessageRoleUser represents a message from the user. + ChatMessageRoleUser ChatMessageRole = "user" ) +// ChatMessage represents a message from a chat thread with a model. type ChatMessage struct { Content *string `json:"content,omitempty"` Role ChatMessageRole `json:"role"` } +// ChatCompletionOptions represents available options for a chat completion request. type ChatCompletionOptions struct { MaxTokens *int `json:"max_tokens,omitempty"` Messages []ChatMessage `json:"messages"` @@ -30,27 +36,30 @@ type ChatCompletionOptions struct { TopP *float64 `json:"top_p,omitempty"` } -type ChatChoiceMessage struct { +type chatChoiceMessage struct { Content *string `json:"content,omitempty"` Role *string `json:"role,omitempty"` } -type ChatChoiceDelta struct { +type chatChoiceDelta struct { Content *string `json:"content,omitempty"` Role *string `json:"role,omitempty"` } +// ChatChoice represents a choice in a chat completion. type ChatChoice struct { - Delta *ChatChoiceDelta `json:"delta,omitempty"` + Delta *chatChoiceDelta `json:"delta,omitempty"` FinishReason string `json:"finish_reason"` Index int32 `json:"index"` - Message *ChatChoiceMessage `json:"message,omitempty"` + Message *chatChoiceMessage `json:"message,omitempty"` } +// ChatCompletion represents a chat completion. type ChatCompletion struct { Choices []ChatChoice `json:"choices"` } +// ChatCompletionResponse represents a response to a chat completion request. type ChatCompletionResponse struct { Reader sse.Reader[ChatCompletion] } @@ -71,6 +80,7 @@ type modelCatalogSearchSummary struct { Summary string `json:"summary"` } +// ModelSummary includes basic information about a model. type ModelSummary struct { ID string `json:"id"` Name string `json:"name"` @@ -82,6 +92,7 @@ type ModelSummary struct { RegistryName string `json:"registry_name"` } +// HasName checks if the model has the given name. func (m *ModelSummary) HasName(name string) bool { return strings.EqualFold(m.FriendlyName, name) || strings.EqualFold(m.Name, name) } @@ -119,6 +130,7 @@ type modelCatalogDetailsResponse struct { } `json:"modelLimits"` } +// ModelDetails includes detailed information about a model. type ModelDetails struct { Description string `json:"description"` Evaluation string `json:"evaluation"` @@ -134,10 +146,7 @@ type ModelDetails struct { RateLimitTier string `json:"rateLimitTier"` } +// ContextLimits returns a summary of the context limits for the model. 
func (m *ModelDetails) ContextLimits() string { return fmt.Sprintf("up to %d input tokens and %d output tokens", m.MaxInputTokens, m.MaxOutputTokens) } - -func Ptr[T any](value T) *T { - return &value -} diff --git a/internal/sse/eventreader.go b/internal/sse/eventreader.go index 5eddcc8..391c3e2 100644 --- a/internal/sse/eventreader.go +++ b/internal/sse/eventreader.go @@ -1,5 +1,6 @@ // Forked from https://github.com/Azure/azure-sdk-for-go/blob/4661007ca1fd68b2e31f3502d4282904014fd731/sdk/ai/azopenai/event_reader.go#L18 +// Package sse provides a reader for Server-Sent Events (SSE) streams. package sse import ( diff --git a/internal/sse/mockeventreader.go b/internal/sse/mockeventreader.go index aa015a7..b8b4a03 100644 --- a/internal/sse/mockeventreader.go +++ b/internal/sse/mockeventreader.go @@ -7,7 +7,7 @@ import ( ) // MockEventReader is a mock implementation of the sse.EventReader. This lets us use EventReader as a common interface -// for models that support streaming (like gpt-4o) and models that do not (like the o1 class of models) +// for models that support streaming (like gpt-4o) and models that do not (like the o1 class of models). type MockEventReader[T any] struct { reader io.ReadCloser scanner *bufio.Scanner @@ -15,6 +15,7 @@ type MockEventReader[T any] struct { index int } +// NewMockEventReader creates a new MockEventReader with the given events. func NewMockEventReader[T any](events []T) *MockEventReader[T] { data := []byte{} reader := io.NopCloser(bytes.NewReader(data)) @@ -22,6 +23,7 @@ func NewMockEventReader[T any](events []T) *MockEventReader[T] { return &MockEventReader[T]{reader: reader, scanner: scanner, events: events, index: 0} } +// Read reads the next event from the stream. func (mer *MockEventReader[T]) Read() (T, error) { if mer.index >= len(mer.events) { var zero T @@ -32,6 +34,7 @@ func (mer *MockEventReader[T]) Read() (T, error) { return event, nil } +// Close closes the Reader and any applicable inner stream state. func (mer *MockEventReader[T]) Close() error { return mer.reader.Close() } diff --git a/internal/ux/filtering.go b/internal/ux/filtering.go index 89dcc17..f456c85 100644 --- a/internal/ux/filtering.go +++ b/internal/ux/filtering.go @@ -1,19 +1,9 @@ +// Package ux provides utility functions around presentation and user experience. package ux -import ( - "github.com/github/gh-models/internal/azure_models" -) +import "github.com/github/gh-models/internal/azuremodels" -func IsChatModel(model *azure_models.ModelSummary) bool { +// IsChatModel returns true if the given model is for chat completions. +func IsChatModel(model *azuremodels.ModelSummary) bool { return model.Task == "chat-completion" } - -func FilterToChatModels(models []*azure_models.ModelSummary) []*azure_models.ModelSummary { - var chatModels []*azure_models.ModelSummary - for _, model := range models { - if IsChatModel(model) { - chatModels = append(chatModels, model) - } - } - return chatModels -} diff --git a/internal/ux/sorting.go b/internal/ux/sorting.go index c8c66d6..59b0e4e 100644 --- a/internal/ux/sorting.go +++ b/internal/ux/sorting.go @@ -5,14 +5,15 @@ import ( "sort" "strings" - "github.com/github/gh-models/internal/azure_models" + "github.com/github/gh-models/internal/azuremodels" ) var ( featuredModelNames = []string{} ) -func SortModels(models []*azure_models.ModelSummary) { +// SortModels sorts the given models in place, with featured models first, and then by friendly name. 
+func SortModels(models []*azuremodels.ModelSummary) { sort.Slice(models, func(i, j int) bool { // Sort featured models first, by name isFeaturedI := slices.Contains(featuredModelNames, models[i].Name) @@ -20,15 +21,17 @@ func SortModels(models []*azure_models.ModelSummary) { if isFeaturedI && !isFeaturedJ { return true - } else if !isFeaturedI && isFeaturedJ { - return false - } else { - // Otherwise, sort by friendly name - // Note: sometimes the casing returned by the API is inconsistent, so sort using lowercase values. - friendlyNameI := strings.ToLower(models[i].FriendlyName) - friendlyNameJ := strings.ToLower(models[j].FriendlyName) + } - return friendlyNameI < friendlyNameJ + if !isFeaturedI && isFeaturedJ { + return false } + + // Otherwise, sort by friendly name + // Note: sometimes the casing returned by the API is inconsistent, so sort using lowercase values. + friendlyNameI := strings.ToLower(models[i].FriendlyName) + friendlyNameJ := strings.ToLower(models[j].FriendlyName) + + return friendlyNameI < friendlyNameJ }) } diff --git a/main.go b/main.go index 23f6148..6aec694 100644 --- a/main.go +++ b/main.go @@ -1,3 +1,4 @@ +// Package main provides the entry point for the gh-models extension. package main import ( @@ -20,12 +21,12 @@ func main() { } func mainRun() exitCode { - cmd := cmd.NewRootCommand() + rootCmd := cmd.NewRootCommand() exitCode := exitOK ctx := context.Background() - if _, err := cmd.ExecuteContextC(ctx); err != nil { + if _, err := rootCmd.ExecuteContextC(ctx); err != nil { exitCode = exitError } diff --git a/pkg/util/util.go b/pkg/util/util.go new file mode 100644 index 0000000..1856f20 --- /dev/null +++ b/pkg/util/util.go @@ -0,0 +1,20 @@ +// Package util provides utility functions for the gh-models extension. +package util + +import ( + "fmt" + "io" +) + +// WriteToOut writes a message to the given io.Writer. +func WriteToOut(out io.Writer, message string) { + _, err := io.WriteString(out, message) + if err != nil { + fmt.Println("Error writing message:", err) + } +} + +// Ptr returns a pointer to the given value. +func Ptr[T any](value T) *T { + return &value +}
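Reviewer note: the new `pkg/util` package centralizes the generic `Ptr` helper (moved out of `internal/azure_models/types.go`) and the error-checked `WriteToOut` writer that replaces the bare `io.WriteString` calls. A minimal sketch of how call sites combine the two after this change; the `main` wrapper and the literal values here are illustrative, not part of the diff:

```go
package main

import (
	"os"

	"github.com/github/gh-models/internal/azuremodels"
	"github.com/github/gh-models/pkg/util"
)

func main() {
	// Optional request fields are pointers, so util.Ptr keeps call sites terse,
	// just as PopulateFromFlags does with util.Ptr(maxTokens) and friends.
	req := azuremodels.ChatCompletionOptions{
		Model:       "gpt-4o", // illustrative model name
		MaxTokens:   util.Ptr(256),
		Temperature: util.Ptr(0.7),
		Messages: []azuremodels.ChatMessage{
			{Role: azuremodels.ChatMessageRoleUser, Content: util.Ptr("Hello!")},
		},
	}

	// WriteToOut reports a write failure instead of silently discarding it,
	// which is what the linter flagged about the old io.WriteString calls.
	util.WriteToOut(os.Stdout, "prepared a request for "+req.Model+"\n")
}
```

Since `Ptr` is generic, the one helper covers the `*int`, `*float64`, and `*string` fields without per-type wrappers.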
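Likewise, the client methods now accept a `context.Context` threaded down from `cmd.Context()` into `http.NewRequestWithContext`, and the streaming path returns an `sse.Reader` that the caller must drain and close. A hedged end-to-end sketch of that call pattern, assuming a valid token; the model name and the placeholder token are illustrative only:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"

	"github.com/github/gh-models/internal/azuremodels"
	"github.com/github/gh-models/pkg/util"
)

func chat(ctx context.Context, token string) error {
	client := azuremodels.NewClient(token)

	// ListModels now takes a ctx, so the underlying request is built with
	// http.NewRequestWithContext and can be cancelled by the caller.
	models, err := client.ListModels(ctx)
	if err != nil {
		return err
	}
	fmt.Printf("found %d models\n", len(models))

	req := azuremodels.ChatCompletionOptions{
		Model: "gpt-4o", // illustrative; pick a name returned by ListModels
		Messages: []azuremodels.ChatMessage{
			{Role: azuremodels.ChatMessageRoleUser, Content: util.Ptr("Say hello")},
		},
	}

	resp, err := client.GetChatCompletionStream(ctx, req)
	if err != nil {
		return err
	}
	defer resp.Reader.Close()

	// Same drain loop as cmd/run: read SSE events until io.EOF.
	for {
		completion, err := resp.Reader.Read()
		if err != nil {
			if errors.Is(err, io.EOF) {
				return nil
			}
			return err
		}
		for _, choice := range completion.Choices {
			// Streamed responses carry text in .Delta, non-streamed in .Message.
			if choice.Delta != nil && choice.Delta.Content != nil {
				fmt.Print(*choice.Delta.Content)
			} else if choice.Message != nil && choice.Message.Content != nil {
				fmt.Print(*choice.Message.Content)
			}
		}
	}
}

func main() {
	if err := chat(context.Background(), "<github-token>"); err != nil {
		fmt.Println("error:", err)
	}
}
```

Passing the command's context down means Ctrl-C cancellation propagates to in-flight HTTP requests, which the old bare `http.NewRequest` calls could not honor.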