Skip to content

Commit

Permalink
enhance: add field-level sensitivity for prompts
Browse files Browse the repository at this point in the history
Additionally, each field can now also have a description.

This change is made such that all existing tools will work. However,
existing code will need to be updated to support the new types.

Signed-off-by: Donnie Adams <donnie@acorn.io>
  • Loading branch information
thedadams committed Feb 4, 2025
1 parent 7ee5c80 commit e0fdb70
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 25 deletions.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ require (
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f1
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61
github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee
github.com/hexops/autogold/v2 v2.2.1
github.com/hexops/valast v1.4.4
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f
github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f1/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc=
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e h1:WpNae0NBx+Ri8RB3SxF8DhadDKU7h+jfWPQterDpbJA=
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6 h1:vkgNZVWQgbE33VD3z9WKDwuu7B/eJVVMMPM62ixfCR8=
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6/go.mod h1:frrl/B+ZH3VSs3Tqk2qxEIIWTONExX3tuUa4JsVnqx4=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 h1:QxLjsLOYlsVLPwuRkP0Q8EcAoZT1s8vU2ZBSX0+R6CI=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee h1:70PHW6Xw70yNNZ5aX936XqcMLwNmfMZpCV3FCOGKpxE=
github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee/go.mod h1:iwHxuueg2paOak7zIg0ESBWx7A0wIHGopAratbgaPNY=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
Expand Down
1 change: 0 additions & 1 deletion pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,6 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
DisableCache: r.DisableCache,
CredentialOverrides: r.CredentialOverride,
Input: toolInput,
CacheDir: r.CacheDir,
SubTool: r.SubTool,
Workspace: r.Workspace,
SaveChatStateFile: r.SaveChatStateFile,
Expand Down
10 changes: 5 additions & 5 deletions pkg/engine/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ func mergeInputs(base, overlay string) (string, error) {
return base, nil
}

err := json.Unmarshal([]byte(base), &baseMap)
if err != nil {
return "", fmt.Errorf("failed to unmarshal base input: %w", err)
if base != "" {
if err := json.Unmarshal([]byte(base), &baseMap); err != nil {
return "", fmt.Errorf("failed to unmarshal base input: %w", err)
}
}

err = json.Unmarshal([]byte(overlay), &overlayMap)
if err != nil {
if err := json.Unmarshal([]byte(overlay), &overlayMap); err != nil {
return "", fmt.Errorf("failed to unmarshal overlay input: %w", err)
}

Expand Down
24 changes: 12 additions & 12 deletions pkg/prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,20 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types.
func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string) (_ string, err error) {
var params struct {
Message string `json:"message,omitempty"`
Fields string `json:"fields,omitempty"`
Fields types.Fields `json:"fields,omitempty"`
Sensitive string `json:"sensitive,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
if err := json.Unmarshal([]byte(input), &params); err != nil {
return "", err
}

var fields []string
for _, env := range envs {
if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok {

Check failure on line 64 in pkg/prompt/prompt.go

View workflow job for this annotation

GitHub Actions / test (ubuntu-22.04)

unnecessary leading newline (whitespace)
if params.Fields != "" {
fields = strings.Split(params.Fields, ",")
}

httpPrompt := types.Prompt{
Message: params.Message,
Fields: fields,
Fields: params.Fields,
Sensitive: params.Sensitive == "true",
Metadata: params.Metadata,
}
Expand Down Expand Up @@ -102,21 +98,25 @@ func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) {
results := map[string]string{}
for _, f := range req.Fields {
var (
value string
msg = f
value string
msg = f.Name
sensitive = req.Sensitive
)
if f.Sensitive != nil {
sensitive = *f.Sensitive
}
if len(req.Fields) == 1 && req.Message != "" {
msg = req.Message
}
if req.Sensitive {
err = survey.AskOne(&survey.Password{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
if sensitive {
err = survey.AskOne(&survey.Password{Message: msg, Help: f.Description}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
} else {
err = survey.AskOne(&survey.Input{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
err = survey.AskOne(&survey.Input{Message: msg, Help: f.Description}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
}
if err != nil {
return "", err
}
results[f] = value
results[f.Name] = value
}

resultsStr, err := json.Marshal(results)
Expand Down
63 changes: 62 additions & 1 deletion pkg/types/prompt.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,74 @@
package types

import (
"encoding/json"
"strings"
)

const (
PromptURLEnvVar = "GPTSCRIPT_PROMPT_URL"
PromptTokenEnvVar = "GPTSCRIPT_PROMPT_TOKEN"
)

type Prompt struct {
Message string `json:"message,omitempty"`
Fields []string `json:"fields,omitempty"`
Fields Fields `json:"fields,omitempty"`
Sensitive bool `json:"sensitive,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}

type Field struct {
Name string `json:"name,omitempty"`
Sensitive *bool `json:"sensitive,omitempty"`
Description string `json:"description,omitempty"`
}

type Fields []Field

// UnmarshalJSON will unmarshal the corresponding JSON object for Fields,
// or a comma-separated strings (for backwards compatibility).
func (f *Fields) UnmarshalJSON(b []byte) error {
if len(b) == 0 || f == nil {
return nil
}

if b[0] == '[' {
var arr []Field
if err := json.Unmarshal(b, &arr); err != nil {
return err
}
*f = arr
return nil
}

var fields string
if err := json.Unmarshal(b, &fields); err != nil {
return err
}

if fields != "" {
fieldsArr := strings.Split(fields, ",")
*f = make([]Field, 0, len(fieldsArr))
for _, field := range fieldsArr {
*f = append(*f, Field{Name: strings.TrimSpace(field)})
}
}

return nil
}

type field *Field

// UnmarshalJSON will unmarshal the corresponding JSON object for a Field,
// or a string (for backwards compatibility).
func (f *Field) UnmarshalJSON(b []byte) error {
if len(b) == 0 || f == nil {
return nil
}

if b[0] == '{' {
return json.Unmarshal(b, field(f))
}

return json.Unmarshal(b, &f.Name)
}
142 changes: 142 additions & 0 deletions pkg/types/prompt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package types

import (
"reflect"
"testing"
)

func TestFieldUnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input []byte
expected Field
expectErr bool
}{
{
name: "valid single Field object JSON",
input: []byte(`{"name":"field1","sensitive":true,"description":"A test field"}`),
expected: Field{Name: "field1", Sensitive: boolPtr(true), Description: "A test field"},
expectErr: false,
},
{
name: "valid Field name as string",
input: []byte(`"field1"`),
expected: Field{Name: "field1"},
expectErr: false,
},
{
name: "empty input",
input: []byte(``),
expected: Field{},
expectErr: false,
},
{
name: "invalid JSON object",
input: []byte(`{"name":"field1","sensitive":"not_boolean"}`),
expected: Field{Name: "field1", Sensitive: new(bool)},
expectErr: true,
},
{
name: "extra unknown fields in JSON object",
input: []byte(`{"name":"field1","unknown":"field","sensitive":false}`),
expected: Field{Name: "field1", Sensitive: boolPtr(false)},
expectErr: false,
},
{
name: "malformed JSON",
input: []byte(`{"name":"field1","sensitive":true`),
expected: Field{},
expectErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var field Field
err := field.UnmarshalJSON(tt.input)
if (err != nil) != tt.expectErr {
t.Errorf("UnmarshalJSON() error = %v, expectErr %v", err, tt.expectErr)
}
if !reflect.DeepEqual(field, tt.expected) {
t.Errorf("UnmarshalJSON() = %v, expected %v", field, tt.expected)
}
})
}
}

func TestFieldsUnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input []byte
expected Fields
expectErr bool
}{
{
name: "empty input",
input: nil,
expected: nil,
expectErr: false,
},
{
name: "nil pointer",
input: nil,
expected: nil,
expectErr: false,
},
{
name: "valid JSON array",
input: []byte(`[{"Name":"field1"},{"Name":"field2"}]`),
expected: Fields{{Name: "field1"}, {Name: "field2"}},
expectErr: false,
},
{
name: "single string input",
input: []byte(`"field1,field2,field3"`),
expected: Fields{{Name: "field1"}, {Name: "field2"}, {Name: "field3"}},
expectErr: false,
},
{
name: "trim spaces in single string input",
input: []byte(`"field1, field2 , field3 "`),
expected: Fields{{Name: "field1"}, {Name: "field2"}, {Name: "field3"}},
expectErr: false,
},
{
name: "invalid JSON array",
input: []byte(`[{"Name":"field1"},{"Name":field2}]`),
expected: nil,
expectErr: true,
},
{
name: "invalid single string",
input: []byte(`1234`),
expected: nil,
expectErr: true,
},
{
name: "empty array",
input: []byte(`[]`),
expected: Fields{},
expectErr: false,
},
{
name: "empty string",
input: []byte(`""`),
expected: nil,
expectErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var fields Fields
err := fields.UnmarshalJSON(tt.input)
if (err != nil) != tt.expectErr {
t.Errorf("UnmarshalJSON() error = %v, expectErr %v", err, tt.expectErr)
}
if !reflect.DeepEqual(fields, tt.expected) {
t.Errorf("UnmarshalJSON() = %v, expected %v", fields, tt.expected)
}
})
}
}

0 comments on commit e0fdb70

Please sign in to comment.