-
Notifications
You must be signed in to change notification settings - Fork 286
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
enhance: add field-level sensitivity for prompts
Additionally, each field can now also have a description. This change is made such that all existing tools will work. However, existing code will need to be updated to support the new types. Signed-off-by: Donnie Adams <donnie@acorn.io>
- Loading branch information
Showing
7 changed files
with
227 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,74 @@ | ||
package types | ||
|
||
import ( | ||
"encoding/json" | ||
"strings" | ||
) | ||
|
||
const ( | ||
PromptURLEnvVar = "GPTSCRIPT_PROMPT_URL" | ||
PromptTokenEnvVar = "GPTSCRIPT_PROMPT_TOKEN" | ||
) | ||
|
||
type Prompt struct { | ||
Message string `json:"message,omitempty"` | ||
Fields []string `json:"fields,omitempty"` | ||
Fields Fields `json:"fields,omitempty"` | ||
Sensitive bool `json:"sensitive,omitempty"` | ||
Metadata map[string]string `json:"metadata,omitempty"` | ||
} | ||
|
||
type Field struct { | ||
Name string `json:"name,omitempty"` | ||
Sensitive *bool `json:"sensitive,omitempty"` | ||
Description string `json:"description,omitempty"` | ||
} | ||
|
||
type Fields []Field | ||
|
||
// UnmarshalJSON will unmarshal the corresponding JSON object for Fields, | ||
// or a comma-separated strings (for backwards compatibility). | ||
func (f *Fields) UnmarshalJSON(b []byte) error { | ||
if len(b) == 0 || f == nil { | ||
return nil | ||
} | ||
|
||
if b[0] == '[' { | ||
var arr []Field | ||
if err := json.Unmarshal(b, &arr); err != nil { | ||
return err | ||
} | ||
*f = arr | ||
return nil | ||
} | ||
|
||
var fields string | ||
if err := json.Unmarshal(b, &fields); err != nil { | ||
return err | ||
} | ||
|
||
if fields != "" { | ||
fieldsArr := strings.Split(fields, ",") | ||
*f = make([]Field, 0, len(fieldsArr)) | ||
for _, field := range fieldsArr { | ||
*f = append(*f, Field{Name: strings.TrimSpace(field)}) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
type field *Field | ||
|
||
// UnmarshalJSON will unmarshal the corresponding JSON object for a Field, | ||
// or a string (for backwards compatibility). | ||
func (f *Field) UnmarshalJSON(b []byte) error { | ||
if len(b) == 0 || f == nil { | ||
return nil | ||
} | ||
|
||
if b[0] == '{' { | ||
return json.Unmarshal(b, field(f)) | ||
} | ||
|
||
return json.Unmarshal(b, &f.Name) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
package types | ||
|
||
import ( | ||
"reflect" | ||
"testing" | ||
) | ||
|
||
func TestFieldUnmarshalJSON(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input []byte | ||
expected Field | ||
expectErr bool | ||
}{ | ||
{ | ||
name: "valid single Field object JSON", | ||
input: []byte(`{"name":"field1","sensitive":true,"description":"A test field"}`), | ||
expected: Field{Name: "field1", Sensitive: boolPtr(true), Description: "A test field"}, | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "valid Field name as string", | ||
input: []byte(`"field1"`), | ||
expected: Field{Name: "field1"}, | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "empty input", | ||
input: []byte(``), | ||
expected: Field{}, | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "invalid JSON object", | ||
input: []byte(`{"name":"field1","sensitive":"not_boolean"}`), | ||
expected: Field{Name: "field1", Sensitive: new(bool)}, | ||
expectErr: true, | ||
}, | ||
{ | ||
name: "extra unknown fields in JSON object", | ||
input: []byte(`{"name":"field1","unknown":"field","sensitive":false}`), | ||
expected: Field{Name: "field1", Sensitive: boolPtr(false)}, | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "malformed JSON", | ||
input: []byte(`{"name":"field1","sensitive":true`), | ||
expected: Field{}, | ||
expectErr: true, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
var field Field | ||
err := field.UnmarshalJSON(tt.input) | ||
if (err != nil) != tt.expectErr { | ||
t.Errorf("UnmarshalJSON() error = %v, expectErr %v", err, tt.expectErr) | ||
} | ||
if !reflect.DeepEqual(field, tt.expected) { | ||
t.Errorf("UnmarshalJSON() = %v, expected %v", field, tt.expected) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestFieldsUnmarshalJSON(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input []byte | ||
expected Fields | ||
expectErr bool | ||
}{ | ||
{ | ||
name: "empty input", | ||
input: nil, | ||
expected: nil, | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "nil pointer", | ||
input: nil, | ||
expected: nil, | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "valid JSON array", | ||
input: []byte(`[{"Name":"field1"},{"Name":"field2"}]`), | ||
expected: Fields{{Name: "field1"}, {Name: "field2"}}, | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "single string input", | ||
input: []byte(`"field1,field2,field3"`), | ||
expected: Fields{{Name: "field1"}, {Name: "field2"}, {Name: "field3"}}, | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "trim spaces in single string input", | ||
input: []byte(`"field1, field2 , field3 "`), | ||
expected: Fields{{Name: "field1"}, {Name: "field2"}, {Name: "field3"}}, | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "invalid JSON array", | ||
input: []byte(`[{"Name":"field1"},{"Name":field2}]`), | ||
expected: nil, | ||
expectErr: true, | ||
}, | ||
{ | ||
name: "invalid single string", | ||
input: []byte(`1234`), | ||
expected: nil, | ||
expectErr: true, | ||
}, | ||
{ | ||
name: "empty array", | ||
input: []byte(`[]`), | ||
expected: Fields{}, | ||
expectErr: false, | ||
}, | ||
{ | ||
name: "empty string", | ||
input: []byte(`""`), | ||
expected: nil, | ||
expectErr: false, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
var fields Fields | ||
err := fields.UnmarshalJSON(tt.input) | ||
if (err != nil) != tt.expectErr { | ||
t.Errorf("UnmarshalJSON() error = %v, expectErr %v", err, tt.expectErr) | ||
} | ||
if !reflect.DeepEqual(fields, tt.expected) { | ||
t.Errorf("UnmarshalJSON() = %v, expected %v", fields, tt.expected) | ||
} | ||
}) | ||
} | ||
} |