Skip to content

Commit

Permalink
add --json option to psort, should help minimize 'prompt injection'-t…
Browse files Browse the repository at this point in the history
…ype errors
  • Loading branch information
ozreact committed May 30, 2023
1 parent 2b51231 commit 4c12245
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 10 deletions.
6 changes: 6 additions & 0 deletions cmd/ambrosia/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,12 @@ func main() {
EnvVars: []string{"AMBROSIA_FIELDS", "FIELDS"},
Usage: "the json `FIELD`(s) to use for prompts. All fields used in random order if not specified.",
},
&cli.BoolFlag{
Name: "json",
Aliases: []string{"j"},
EnvVars: []string{"AMBROSIA_JSON", "JSON"},
Usage: "send data portion of prompt as a json object, instead of a string with fields",
},
&cli.BoolFlag{
Name: "include-resp",
Aliases: []string{"ir"},
Expand Down
18 changes: 18 additions & 0 deletions internal/datum.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package internal

import (
"encoding/json"
"fmt"
"reflect"
"strings"
Expand All @@ -22,6 +23,23 @@ func (d datum) String(keys []string, fields bool) string {
return str.String()
}

func (d datum) JSON(keys []string) ([]byte, error) {
selected := make(datum)

for _, key := range keys {
if val, ok := d[key]; ok {
selected[key] = val
}
}

bytes, err := json.Marshal(selected)
if err != nil {
return nil, err
}

return bytes, nil
}

func isEqual(a datum, b datum) bool {
delete(a, "ambrosia")
delete(b, "ambrosia")
Expand Down
61 changes: 61 additions & 0 deletions internal/dedupe_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package internal

import (
"encoding/json"
"flag"
"os"
"reflect"
"strconv"
"testing"

Expand Down Expand Up @@ -215,3 +217,62 @@ func TestDedupe(t *testing.T) {
assert.Equal(t, expected, dedupe(ctx, data))
})
}

func TestDatum_JSON(t *testing.T) {
tests := []struct {
name string
d datum
keys []string
want string
wantErr bool
}{
{
name: "Test with existing keys",
d: datum{
"key1": "value1",
"key2": "value2",
"key3": "value3",
},
keys: []string{"key1", "key3"},
want: `{"key1":"value1","key3":"value3"}`,
wantErr: false,
},
{
name: "Test with non-existing keys",
d: datum{
"key1": "value1",
"key2": "value2",
"key3": "value3",
},
keys: []string{"key4", "key5"},
want: `{}`,
wantErr: false,
},
{
name: "Test with empty datum",
d: datum{},
keys: []string{"key1", "key2"},
want: `{}`,
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.d.JSON(tt.keys)
if (err != nil) != tt.wantErr {
t.Errorf("datum.JSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
var gotMap map[string]interface{}
json.Unmarshal([]byte(got), &gotMap)

var wantMap map[string]interface{}
json.Unmarshal([]byte(tt.want), &wantMap)

if !reflect.DeepEqual(gotMap, wantMap) {
t.Errorf("datum.JSON() = %v, want %v", got, tt.want)
}
})
}
}
28 changes: 18 additions & 10 deletions internal/psort.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,26 @@ func submitPrompts(c *cmdCtx, queue chan<- providers.InferRequest, data []datum)
fmt.Fprintf(&b, "%s\n\n", c.c.String("instruction"))
}

// Handle 'all fields' case
if len(c.c.StringSlice("fields")) == 0 {
for field, value := range d {
fmt.Fprintf(&b, "%s: %v\n", field, value)
if c.c.Bool("json") {
jb, err := d.JSON(c.c.StringSlice("fields"))
if err != nil {
c.logger.Fatal().Err(err).Msg("error marshalling json")
}
fmt.Fprintf(&b, "%s\n", string(jb))
} else {
// Handle 'all fields' case
if len(c.c.StringSlice("fields")) == 0 {
for field, value := range d {
fmt.Fprintf(&b, "%s: %v\n", field, value)
}
}
}

// Handle specific fields
for _, field := range c.c.StringSlice("fields") {
val, ok := d[field]
if ok {
fmt.Fprintf(&b, "%s: %v\n", field, val)
// Handle specific fields
for _, field := range c.c.StringSlice("fields") {
val, ok := d[field]
if ok {
fmt.Fprintf(&b, "%s: %v\n", field, val)
}
}
}

Expand Down

0 comments on commit 4c12245

Please sign in to comment.