✨ Restore embeddings command
wesen committed Mar 1, 2023
1 parent 45af680 commit 7f0a3b5
Showing 5 changed files with 171 additions and 106 deletions.
2 changes: 1 addition & 1 deletion .goreleaser.yaml
@@ -39,7 +39,7 @@ brews:
homepage: "https://github.com/go-go-golems/pinocchio"
tap:
owner: go-go-golems
name: homebrew-go-go
name: homebrew-go-go-go
token: "{{ .Env.TAP_GITHUB_TOKEN }}"

# modelines, feel free to remove those if you don't want/use them:
254 changes: 161 additions & 93 deletions cmd/pinocchio/cmds/openai/embedding.go
@@ -1,95 +1,163 @@
package openai

//var EmbeddingsCmd = &cobra.Command{
// Use: "embeddings",
// Short: "Compute embeddings for a series of files",
// Args: cobra.MinimumNArgs(1),
// Run: func(cmd *cobra.Command, args []string) {
// user, _ := cmd.PersistentFlags().GetString("user")
//
// prompts := []string{}
//
// for _, file := range args {
//
// if file == "-" {
// file = "/dev/stdin"
// }
//
// f, err := os.ReadFile(file)
// cobra.CheckErr(err)
//
// prompts = append(prompts, string(f))
// }
//
// // TODO(manuel, 2023-01-28) actually I don't think it's a good idea to go through the stepfactory here
// // we just want to have the RAW api access with all its outputs
//
// clientSettings, err := openai.NewClientSettingsFromCobra(cmd)
// cobra.CheckErr(err)
//
// err = completionStepFactory.UpdateFromParameters(cmd)
// cobra.CheckErr(err)
//
// client, err := clientSettings.CreateClient()
// cobra.CheckErr(err)
//
// engine, _ := cmd.Flags().GetString("engine")
//
// ctx := context.Background()
// resp, err := client.Embeddings(ctx, gpt3.EmbeddingsRequest{
// Input: prompts,
// Model: engine,
// User: user,
// })
// cobra.CheckErr(err)
//
// printUsage, _ := cmd.Flags().GetBool("print-usage")
// usage := resp.Usage
// evt := log.Debug()
// if printUsage {
// evt = log.Info()
// }
// evt.
// Int("prompt-tokens", usage.PromptTokens).
// Int("total-tokens", usage.TotalTokens).
// Msg("Usage")
//
// gp, of, err := cli.SetupProcessor(cmd)
// cobra.CheckErr(err)
//
// printRawResponse, _ := cmd.Flags().GetBool("print-raw-response")
//
// if printRawResponse {
// // serialize resp to json
// rawResponse, err := json.MarshalIndent(resp, "", " ")
// cobra.CheckErr(err)
//
// // deserialize to map[string]interface{}
// var rawResponseMap map[string]interface{}
// err = json.Unmarshal(rawResponse, &rawResponseMap)
// cobra.CheckErr(err)
//
// err = gp.ProcessInputObject(rawResponseMap)
// cobra.CheckErr(err)
//
// } else {
// for _, embedding := range resp.Data {
// row := map[string]interface{}{
// "object": embedding.Object,
// "embedding": embedding.Embedding,
// "index": embedding.Index,
// }
// err = gp.ProcessInputObject(row)
// cobra.CheckErr(err)
// }
// }
//
// s, err := of.Output()
// if err != nil {
// _, _ = fmt.Fprintf(os.Stderr, "Error rendering output: %s\n", err)
// os.Exit(1)
// }
// fmt.Print(s)
// cobra.CheckErr(err)
// },
//}
import (
"context"
"encoding/json"
"github.com/PullRequestInc/go-gpt3"
"github.com/go-go-golems/geppetto/pkg/steps/openai"
"github.com/go-go-golems/glazed/pkg/cli"
"github.com/go-go-golems/glazed/pkg/cmds"
"github.com/go-go-golems/glazed/pkg/cmds/layers"
"github.com/go-go-golems/glazed/pkg/cmds/parameters"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"os"
)

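// EmbeddingsCommand exposes the OpenAI embeddings endpoint as a glazed command
// that emits structured rows rather than printing directly to stdout.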
type EmbeddingsCommand struct {
description *cmds.CommandDescription
}

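// NewEmbeddingsCommand assembles the command description: the glazed output
// layer (defaulting to JSON), the OpenAI client and completion layers, the
// embeddings flags, and the required input-files argument.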
func NewEmbeddingsCommand() (*EmbeddingsCommand, error) {
glazedParameterLayer, err := cli.NewGlazedParameterLayers(
cli.WithOutputParameterLayerOptions(
layers.WithDefaults(map[string]interface{}{
"output": "json",
},
),
),
)
if err != nil {
return nil, err
}
openaiParameterLayer, err := openai.NewClientParameterLayer()
if err != nil {
return nil, err
}
completionParameterLayer, err := openai.NewCompletionParameterLayer()
if err != nil {
return nil, err
}

return &EmbeddingsCommand{
description: cmds.NewCommandDescription(
"embeddings",
cmds.WithShort("send a prompt to the embeddings API"),
cmds.WithFlags(
parameters.NewParameterDefinition(
"print-usage",
parameters.ParameterTypeBool,
parameters.WithHelp("print usage"),
parameters.WithDefault(false),
),
parameters.NewParameterDefinition(
"print-raw-response",
parameters.ParameterTypeBool,
parameters.WithHelp("print raw response as object"),
parameters.WithDefault(false),
),
parameters.NewParameterDefinition(
"model",
parameters.ParameterTypeString,
parameters.WithHelp("model to use"),
parameters.WithDefault("text-embedding-ada-002"),
),
),
cmds.WithArguments(
parameters.NewParameterDefinition(
"input-files",
parameters.ParameterTypeStringList,
parameters.WithRequired(true),
),
),
cmds.WithLayers(
glazedParameterLayer,
completionParameterLayer,
openaiParameterLayer,
),
),
}, nil

}

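// Description returns the command description consumed by the cobra command builder.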
func (c *EmbeddingsCommand) Description() *cmds.CommandDescription {
return c.description
}

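// Run reads the input files (stdin when "-"), sends them to the embeddings API
// in a single request, and feeds the response to the glaze processor.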
func (c *EmbeddingsCommand) Run(
ctx context.Context,
parsedLayers map[string]*layers.ParsedParameterLayer,
ps map[string]interface{},
gp *cmds.GlazeProcessor,
) error {
user, _ := ps["user"].(string)

inputFiles, _ := ps["input-files"].([]string)
prompts := []string{}

for _, file := range inputFiles {

if file == "-" {
file = "/dev/stdin"
}

f, err := os.ReadFile(file)
cobra.CheckErr(err)

prompts = append(prompts, string(f))
}

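// Create the OpenAI client from the parsed parameters.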
clientSettings, err := openai.NewClientSettingsFromParameters(ps)
cobra.CheckErr(err)

client, err := clientSettings.CreateClient()
cobra.CheckErr(err)

engine, _ := ps["model"].(string)

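// One request carries all prompts; the model defaults to text-embedding-ada-002.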
resp, err := client.Embeddings(ctx, gpt3.EmbeddingsRequest{
Input: prompts,
Model: engine,
User: user,
})
cobra.CheckErr(err)

printUsage, _ := ps["print-usage"].(bool)
printRawResponse, _ := ps["print-raw-response"].(bool)

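// Token usage is logged at debug level, or at info level when --print-usage is set.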
usage := resp.Usage
evt := log.Debug()
if printUsage {
evt = log.Info()
}
evt.
Int("prompt-tokens", usage.PromptTokens).
Int("total-tokens", usage.TotalTokens).
Msg("Usage")

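// Either emit the raw API response as a single object, or one row per embedding.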
if printRawResponse {
// serialize resp to json
rawResponse, err := json.MarshalIndent(resp, "", " ")
cobra.CheckErr(err)

// deserialize to map[string]interface{}
var rawResponseMap map[string]interface{}
err = json.Unmarshal(rawResponse, &rawResponseMap)
cobra.CheckErr(err)

err = gp.ProcessInputObject(rawResponseMap)
cobra.CheckErr(err)
} else {
for _, embedding := range resp.Data {
row := map[string]interface{}{
"object": embedding.Object,
"embedding": embedding.Embedding,
"index": embedding.Index,
}
err = gp.ProcessInputObject(row)
cobra.CheckErr(err)
}
}

return nil
}
13 changes: 5 additions & 8 deletions cmd/pinocchio/cmds/openai/openai.go
@@ -408,14 +408,11 @@ func init() {
cobra.CheckErr(err)
OpenaiCmd.AddCommand(cobraEngineInfoCommand)

//EmbeddingsCmd.Flags().Bool("print-usage", false, "print usage")
//EmbeddingsCmd.Flags().Bool("print-raw-response", false, "print raw response as object")
//err = cli.AddFlags(EmbeddingsCmd, cli.NewFlagsDefaults())
//if err != nil {
// panic(err)
//}
//EmbeddingsCmd.Flags().String("engine", gpt3.TextDavinci002Engine, "engine to use")
//OpenaiCmd.AddCommand(EmbeddingsCmd)
embeddingsCommand, err := NewEmbeddingsCommand()
cobra.CheckErr(err)
cobraEmbeddingsCommand, err := cli.BuildCobraCommand(embeddingsCommand)
cobra.CheckErr(err)
OpenaiCmd.AddCommand(cobraEmbeddingsCommand)

err = cli.AddGlazedProcessorFlagsToCobraCommand(FamiliesCmd)
if err != nil {
4 changes: 2 additions & 2 deletions go.mod
@@ -8,8 +8,9 @@ require (
github.com/charmbracelet/bubbletea v0.23.1
github.com/charmbracelet/lipgloss v0.6.0
github.com/go-go-golems/clay v0.0.6
github.com/go-go-golems/glazed v0.2.19
github.com/go-go-golems/glazed v0.2.24
github.com/go-go-golems/parka v0.2.4
github.com/mb0/glob v0.0.0-20160210091149-1eb79d2de6c4
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.29.0
github.com/spf13/cobra v1.6.1
@@ -61,7 +62,6 @@ require (
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/mb0/glob v0.0.0-20160210091149-1eb79d2de6c4 // indirect
github.com/microcosm-cc/bluemonday v1.0.21 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
4 changes: 2 additions & 2 deletions go.sum
@@ -116,8 +116,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-go-golems/clay v0.0.6 h1:wMOByY3io8Pk8P/K4L4RkFvefT3eMOo9ndTF+Dr07X0=
github.com/go-go-golems/clay v0.0.6/go.mod h1:5xTh9jOtcCzPPbMY9ZBEn4pd/VmXjv9W7+MKqmjog6s=
github.com/go-go-golems/glazed v0.2.19 h1:p+ephsK0AJPjofhv5F1aj84edoNQtQNzW175/Wq3BWE=
github.com/go-go-golems/glazed v0.2.19/go.mod h1:Nf4vx3TF7hrt66F1TALSOcvIvPBnq38LjKUy5HyBS6g=
github.com/go-go-golems/glazed v0.2.24 h1:SNDRb0IbrwDSbHBFP/a9C9Pb+yuLN5PH/aQH2HaWFSc=
github.com/go-go-golems/glazed v0.2.24/go.mod h1:Nf4vx3TF7hrt66F1TALSOcvIvPBnq38LjKUy5HyBS6g=
github.com/go-go-golems/parka v0.2.4 h1:BrCKtKC45XzROnbx8HUVtV57Ko33RmZyC+D7wtxxqPc=
github.com/go-go-golems/parka v0.2.4/go.mod h1:lBfAUOgP2ZEDgFYg0nrphenkReptw5+uXmH8GCAMRis=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
