Skip to content

Commit

Permalink
[Go] Make dotprompt take in jsonschema.Schema instead of `map[strin…
Browse files Browse the repository at this point in the history
…g]any` for better ergonomics. (#645)
  • Loading branch information
apascal07 authored Jul 18, 2024
1 parent a7ad569 commit 3318226
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 25 deletions.
9 changes: 6 additions & 3 deletions go/plugins/dotprompt/dotprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,17 @@ type Config struct {
// Details for the model.
GenerationConfig *ai.GenerationCommonConfig

InputSchema *jsonschema.Schema // schema for input variables
VariableDefaults map[string]any // default input variable values
// Schema for input variables.
InputSchema *jsonschema.Schema

// Default input variable values
VariableDefaults map[string]any

// Desired output format.
OutputFormat ai.OutputFormat

// Desired output schema, for JSON output.
OutputSchema map[string]any // TODO: use *jsonschema.Schema
OutputSchema *jsonschema.Schema

// Arbitrary metadata.
Metadata map[string]any
Expand Down
8 changes: 2 additions & 6 deletions go/plugins/dotprompt/dotprompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,8 @@ func TestPrompts(t *testing.T) {
t.Errorf("unexpected output schema: %v", prompt.OutputSchema)
}
} else {
var output map[string]any
if err := json.Unmarshal([]byte(test.output), &output); err != nil {
t.Fatalf("JSON unmarshal of %q failed: %v", test.output, err)
}
if diff := cmp.Diff(output, prompt.OutputSchema); diff != "" {
t.Errorf("output schema mismatch (-want, +got):\n%s", diff)
if diff := cmpSchema(t, prompt.OutputSchema, test.output); diff != "" {
t.Errorf("input schema mismatch (-want, +got):\n%s", diff)
}
}
})
Expand Down
15 changes: 14 additions & 1 deletion go/plugins/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package dotprompt

import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
Expand Down Expand Up @@ -117,9 +118,21 @@ func (p *Prompt) buildRequest(ctx context.Context, input any) (*ai.GenerateReque

req.Config = p.GenerationConfig

var outputSchema map[string]any
if p.OutputSchema != nil {
jsonBytes, err := p.OutputSchema.MarshalJSON()
if err != nil {
return nil, fmt.Errorf("failed to marshal output schema JSON: %w", err)
}
err = json.Unmarshal(jsonBytes, &outputSchema)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal output schema JSON: %w", err)
}
}

req.Output = &ai.GenerateRequestOutput{
Format: p.OutputFormat,
Schema: p.OutputSchema,
Schema: outputSchema,
}

req.Tools = p.Tools
Expand Down
17 changes: 2 additions & 15 deletions go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ package main

import (
"context"
"encoding/json"
"fmt"
"log"

Expand Down Expand Up @@ -172,24 +171,12 @@ func main() {
return text, nil
})

schema := r.Reflect(simpleGreetingOutput{})
jsonBytes, err := schema.MarshalJSON()
if err != nil {
log.Fatal(err)
}

var outputSchema map[string]any
err = json.Unmarshal(jsonBytes, &outputSchema)
if err != nil {
log.Fatal(err)
}

simpleStructuredGreetingPrompt, err := dotprompt.Define("simpleStructuredGreeting", simpleStructuredGreetingPromptTemplate,
dotprompt.Config{
Model: g,
InputSchema: jsonschema.Reflect(simpleGreetingInput{}),
InputSchema: r.Reflect(simpleGreetingInput{}),
OutputFormat: ai.OutputFormatJSON,
OutputSchema: outputSchema,
OutputSchema: r.Reflect(simpleGreetingOutput{}),
},
)
if err != nil {
Expand Down

0 comments on commit 3318226

Please sign in to comment.