Skip to content

Commit

Permalink
Merge pull request #1 from github/brannon/modelCatalogV2
Browse files Browse the repository at this point in the history
Starting using the official model catalog API
  • Loading branch information
brannon authored Sep 19, 2024
2 parents c5ba675 + cb528c0 commit 25a37cb
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 22 deletions.
16 changes: 16 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Run models list",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/main.go",
"args": ["list"]
}
]
}
28 changes: 20 additions & 8 deletions cmd/list/list.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
package list

import (
"fmt"
"io"

"github.com/cli/go-gh/v2/pkg/auth"
"github.com/cli/go-gh/v2/pkg/tableprinter"
"github.com/cli/go-gh/v2/pkg/term"
"github.com/github/gh-models/internal/azure_models"
"github.com/github/gh-models/internal/ux"
"github.com/mgutz/ansi"
"github.com/spf13/cobra"
)

var (
lightGrayUnderline = ansi.ColorFunc("white+du")
)

func NewListCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "list",
Expand All @@ -33,22 +39,28 @@ func NewListCommand() *cobra.Command {
return err
}

// For now, filter to just chat models.
// Once other tasks are supported (like embeddings), update the list to show all models, with the task as a column.
models = ux.FilterToChatModels(models)
ux.SortModels(models)

isTTY := terminal.IsTerminalOutput()

if isTTY {
io.WriteString(out, "\n")
io.WriteString(out, fmt.Sprintf("Showing %d available chat models\n", len(models)))
io.WriteString(out, "\n")
}

width, _, _ := terminal.Size()
printer := tableprinter.New(out, terminal.IsTerminalOutput(), width)
printer := tableprinter.New(out, isTTY, width)

printer.AddHeader([]string{"Name", "Friendly Name", "Publisher"})
printer.AddHeader([]string{"Display Name", "Model Name"}, tableprinter.WithColor(lightGrayUnderline))
printer.EndRow()

for _, model := range models {
if !ux.IsChatModel(model) {
continue
}

printer.AddField(model.Name)
printer.AddField(model.FriendlyName)
printer.AddField(model.Publisher)
printer.AddField(model.Name)
printer.EndRow()
}

Expand Down
42 changes: 38 additions & 4 deletions internal/azure_models/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Client struct {

const (
prodInferenceURL = "https://models.inference.ai.azure.com/chat/completions"
prodModelsURL = "https://models.inference.ai.azure.com/models"
prodModelsURL = "https://api.catalog.azureml.ms/asset-gallery/v1.0/models"
)

func NewClient(authToken string) *Client {
Expand Down Expand Up @@ -66,11 +66,25 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple
}

func (c *Client) ListModels() ([]*ModelSummary, error) {
httpReq, err := http.NewRequest("GET", prodModelsURL, http.NoBody)
body := bytes.NewReader([]byte(`
{
"filters": [
{ "field": "freePlayground", "values": ["true"], "operator": "eq"},
{ "field": "labels", "values": ["latest"], "operator": "eq"}
],
"order": [
{ "field": "displayName", "direction": "asc" }
]
}
`))

httpReq, err := http.NewRequest("POST", prodModelsURL, body)
if err != nil {
return nil, err
}

httpReq.Header.Set("Content-Type", "application/json")

resp, err := c.client.Do(httpReq)
if err != nil {
return nil, err
Expand All @@ -80,12 +94,32 @@ func (c *Client) ListModels() ([]*ModelSummary, error) {
return nil, c.handleHTTPError(resp)
}

var models []*ModelSummary
err = json.NewDecoder(resp.Body).Decode(&models)
decoder := json.NewDecoder(resp.Body)
decoder.UseNumber()

var searchResponse modelCatalogSearchResponse
err = decoder.Decode(&searchResponse)
if err != nil {
return nil, err
}

models := make([]*ModelSummary, 0, len(searchResponse.Summaries))
for _, summary := range searchResponse.Summaries {
inferenceTask := ""
if len(summary.InferenceTasks) > 0 {
inferenceTask = summary.InferenceTasks[0]
}

models = append(models, &ModelSummary{
ID: summary.AssetID,
Name: summary.Name,
FriendlyName: summary.DisplayName,
Task: inferenceTask,
Publisher: summary.Publisher,
Summary: summary.Summary,
})
}

return models, nil
}

Expand Down
21 changes: 20 additions & 1 deletion internal/azure_models/types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package azure_models

import "github.com/github/gh-models/internal/sse"
import (
"encoding/json"

"github.com/github/gh-models/internal/sse"
)

type ChatMessageRole string

Expand Down Expand Up @@ -49,6 +53,21 @@ type ChatCompletionResponse struct {
Reader *sse.EventReader[ChatCompletion]
}

type modelCatalogSearchResponse struct {
Summaries []modelCatalogSearchSummary `json:"summaries"`
}

type modelCatalogSearchSummary struct {
AssetID string `json:"assetId"`
DisplayName string `json:"displayName"`
InferenceTasks []string `json:"inferenceTasks"`
Name string `json:"name"`
Popularity json.Number `json:"popularity"`
Publisher string `json:"publisher"`
RegistryName string `json:"registryName"`
Summary string `json:"summary"`
}

type ModelSummary struct {
ID string `json:"id"`
Name string `json:"name"`
Expand Down
10 changes: 10 additions & 0 deletions internal/ux/filtering.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,13 @@ import (
func IsChatModel(model *azure_models.ModelSummary) bool {
return model.Task == "chat-completion"
}

func FilterToChatModels(models []*azure_models.ModelSummary) []*azure_models.ModelSummary {
var chatModels []*azure_models.ModelSummary
for _, model := range models {
if IsChatModel(model) {
chatModels = append(chatModels, model)
}
}
return chatModels
}
12 changes: 7 additions & 5 deletions internal/ux/sorting.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ux
import (
"slices"
"sort"
"strings"

"github.com/github/gh-models/internal/azure_models"
)
Expand All @@ -22,11 +23,12 @@ func SortModels(models []*azure_models.ModelSummary) {
} else if !isFeaturedI && isFeaturedJ {
return false
} else {
// Otherwise, sort by publisher and then friendly name
if models[i].Publisher == models[j].Publisher {
return models[i].FriendlyName < models[j].FriendlyName
}
return models[i].Publisher < models[j].Publisher
// Otherwise, sort by friendly name
// Note: sometimes the casing returned by the API is inconsistent, so sort using lowercase values.
friendlyNameI := strings.ToLower(models[i].FriendlyName)
friendlyNameJ := strings.ToLower(models[j].FriendlyName)

return friendlyNameI < friendlyNameJ
}
})
}
20 changes: 16 additions & 4 deletions script/build
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,20 @@ function build {
go build -o ${ROOT}/${BINARY} ${ROOT}/main.go
}

GOOS=windows GOARCH=amd64 build
GOOS=linux GOARCH=amd64 build
GOOS=darwin GOARCH=amd64 build
GOOS=darwin GOARCH=arm64 build
OS=$1

if [[ "$OS" == "windows" || "$OS" == "all" ]]; then
GOOS=windows GOARCH=amd64 build
fi

if [[ "$OS" == "linux" || "$OS" == "all" ]]; then
GOOS=linux GOARCH=amd64 build
fi

if [[ "$OS" == "darwin" || "$OS" == "all" ]]; then
GOOS=darwin GOARCH=amd64 build
GOOS=darwin GOARCH=arm64 build
fi

# Always build the "local" version, which defaults to the current OS/arch
build

0 comments on commit 25a37cb

Please sign in to comment.