Skip to content

Commit

Permalink
internal/{llmapp,gaby}: display policy evaluation results in web UI
Browse files Browse the repository at this point in the history
When policies are enforced (-enforcepolicy), display policy evaluation
results on the Overview web UI page.

To support this, return the detailed results from the llmapp.Overview
functions.

For #70

Change-Id: I5642236f48d205be25f7359b540a72ff86903e4f
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/637979
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
  • Loading branch information
tatianab committed Dec 20, 2024
1 parent 5d12b58 commit b08abec
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 41 deletions.
16 changes: 16 additions & 0 deletions internal/gaby/tmpl/overviewpage.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ function togglePrompt() {
</script>
{{end}}

{{define "show-policy"}}
{{with .Raw.PolicyEvaluation}}
<div class="toggle" onclick="togglePolicy()">[show policy evaluation]</div>
<div id="policy" class="start-hidden">
<pre>{{.Display}}</pre>
</div>
{{end}}
<script>
function togglePolicy() {
var x = document.getElementById("policy");
toggle(x)
}
</script>
{{end}}

{{define "overview-result"}}
<div class="section" id="result">
{{- with .Error -}}
Expand All @@ -60,6 +75,7 @@ function togglePrompt() {
</div>
{{template "show-rawoutput" .}}
{{template "show-prompt" .}}
{{template "show-policy" .}}
{{- else }}
{{if .Params.Query}}<p>No result.</p>{{end}}
{{- end}}
Expand Down
42 changes: 23 additions & 19 deletions internal/llmapp/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package llmapp

import (
"context"
"fmt"
"log/slog"

"golang.org/x/oscar/internal/llm"
Expand All @@ -24,46 +25,49 @@ func NewWithChecker(lg *slog.Logger, g llm.ContentGenerator, checker llm.PolicyC
return &Client{slog: lg, g: g, checker: checker, db: db}
}

// hasPolicyViolation invokes the policy checker on the given prompts and LLM output and
// logs its results. It reports whether any policy violations were found.
// evaluatePolicy invokes the policy checker on the given prompts and LLM output and
// wraps its results as a [*PolicyEvaluation].
// TODO(tatianabradley): Cache calls to policy checker.
func (c *Client) hasPolicyViolation(ctx context.Context, prompts []llm.Part, output string) bool {
func (c *Client) evaluatePolicy(ctx context.Context, prompts []llm.Part, output string) *PolicyEvaluation {
if c.checker == nil {
return false
return nil
}
foundViolation := false
pe := &PolicyEvaluation{}
for _, p := range prompts {
switch v := p.(type) {
case llm.Text:
if c.logCheck(ctx, string(v), nil) {
foundViolation = true
r := c.check(ctx, string(v), nil)
if len(r.Violations) > 0 {
pe.Violative = true
}
pe.PromptResults = append(pe.PromptResults, r)
default:
// Other types are not supported for checks yet.
c.slog.Info("llmapp: can't check policy for prompt part (unsupported type)", "prompt part", v)
err := fmt.Errorf("llmapp: can't check policy for prompt part (unsupported type %T", v)
pe.PromptResults = append(pe.PromptResults, &PolicyResult{Text: "unknown", Error: err})
}
}
if c.logCheck(ctx, output, prompts) {
return true
r := c.check(ctx, output, prompts)
if len(r.Violations) > 0 {
pe.Violative = true
}
return foundViolation
pe.OutputResults = r
return pe
}

// logCheck invokes the policy checker on the give text (with optional prompts)
// and logs its results.
// It reports whether any policy violations were found.
func (c *Client) logCheck(ctx context.Context, text string, prompts []llm.Part) bool {
// check invokes the policy checker on the given text (with optional prompts)
// and returns its results.
func (c *Client) check(ctx context.Context, text string, prompts []llm.Part) *PolicyResult {
prs, err := c.checker.CheckText(ctx, text, prompts...)
if err != nil {
c.slog.Error("llmapp: error checking for policy violations", "err", err)
return false
return &PolicyResult{Text: text, Error: fmt.Errorf("llmapp: error while checking for policy violations: %w", err)}
}
c.slog.Info("llmapp: found policy results", "text", text, "prompts", prompts, "results", toStrings(prs))
if vs := violations(prs); len(vs) > 0 {
c.slog.Warn("llmapp: found policy violations for LLM output", "text", text, "prompts", prompts, "violations", toStrings(vs))
return true
return &PolicyResult{Text: text, Results: prs, Violations: vs}
}
return false
return &PolicyResult{Text: text, Results: prs}
}

func toStrings(prs []*llm.PolicyResult) []string {
Expand Down
66 changes: 56 additions & 10 deletions internal/llmapp/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
Expand All @@ -28,18 +30,55 @@ func TestWithChecker(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !r.HasPolicyViolation {
if !r.HasPolicyViolation() {
t.Errorf("c.Overview.HasPolicyViolation = false, want true")
}
want := &PolicyEvaluation{
Violative: true,
PromptResults: []*PolicyResult{
// doc1
{
Results: []*llm.PolicyResult{violationResult},
Violations: []*llm.PolicyResult{violationResult},
},
// doc2
{Results: []*llm.PolicyResult{okResult}},
// instructions
{Results: []*llm.PolicyResult{okResult}},
},
OutputResults: &PolicyResult{
Results: []*llm.PolicyResult{violationResult},
Violations: []*llm.PolicyResult{violationResult},
},
}
if diff := cmp.Diff(want, r.PolicyEvaluation, cmpopts.IgnoreFields(PolicyResult{}, "Text")); diff != "" {
t.Errorf("c.Overview.PolicyEvaluation mismatch (-want,+got):\n%v", diff)
}

// Without violation.
r, err = c.Overview(context.Background(), doc2)
if err != nil {
t.Fatal(err)
}
if r.HasPolicyViolation {
if r.HasPolicyViolation() {
t.Errorf("c.Overview.HasPolicyViolation = true, want false")
}

want = &PolicyEvaluation{
Violative: false,
PromptResults: []*PolicyResult{
// doc2
{Results: []*llm.PolicyResult{okResult}},
// instructions
{Results: []*llm.PolicyResult{okResult}},
},
OutputResults: &PolicyResult{
Results: []*llm.PolicyResult{okResult},
},
}
if diff := cmp.Diff(want, r.PolicyEvaluation, cmpopts.IgnoreFields(PolicyResult{}, "Text")); diff != "" {
t.Errorf("c.Overview.PolicyEvaluation mismatch (-want,+got):\n%v", diff)
}
}

// badChecker is a test implementation of [llm.PolicyChecker] that
Expand All @@ -50,20 +89,27 @@ type badChecker struct{}
// no-op
func (badChecker) SetPolicies(_ []*llm.PolicyConfig) {}

var (
violationResult = &llm.PolicyResult{
PolicyType: llm.PolicyTypeDangerousContent,
ViolationResult: llm.ViolationResultViolative,
Score: 1,
}
okResult = &llm.PolicyResult{
PolicyType: llm.PolicyTypeDangerousContent,
ViolationResult: llm.ViolationResultNonViolative,
Score: 0,
}
)

// return violation for text containing "bad" and no violation for any other text.
func (badChecker) CheckText(_ context.Context, text string, prompts ...llm.Part) ([]*llm.PolicyResult, error) {
if strings.Contains(text, "bad") {
return []*llm.PolicyResult{
{
PolicyType: llm.PolicyTypeDangerousContent,
ViolationResult: llm.ViolationResultViolative,
},
violationResult,
}, nil
}
return []*llm.PolicyResult{
{
PolicyType: llm.PolicyTypeDangerousContent,
ViolationResult: llm.ViolationResultNonViolative,
},
okResult,
}, nil
}
74 changes: 67 additions & 7 deletions internal/llmapp/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

package llmapp

import "golang.org/x/oscar/internal/llm"
import (
"fmt"
"strings"

"golang.org/x/oscar/internal/llm"
)

// A Doc is a document to provide to an LLM as part of a prompt.
type Doc struct {
Expand All @@ -23,10 +28,65 @@ type Doc struct {

// Result is the result of an LLM call.
type Result struct {
Response string // the raw LLM-generated response
Cached bool // whether the response was cached
Schema *llm.Schema // the JSON schema used to generate the result (nil if none)
Prompt []llm.Part // the prompt(s) used to generate the result
// TODO(tatianabradley): Store the specific policy results instead of just a boolean.
HasPolicyViolation bool // whether any policy violations were found for the inputs or outputs of the LLM
Response string // the raw LLM-generated response
Cached bool // whether the response was cached
Schema *llm.Schema // the JSON schema used to generate the result (nil if none)
Prompt []llm.Part // the prompt(s) used to generate the result
PolicyEvaluation *PolicyEvaluation // (if a policy checker is configured) the policy evaluation result
}

// A PolicyEvaluation is the result of evaluating a policy against
// a multi-part prompt and an output of an LLM.
type PolicyEvaluation struct {
Violative bool // whether any violations were found
PromptResults []*PolicyResult
OutputResults *PolicyResult
}

// String returns a human readable representation of a policy evaluation.
func (pe *PolicyEvaluation) String() string {
if pe == nil {
return ""
}
b := strings.Builder{}
b.WriteString(fmt.Sprintf("Violative: %t\n\n", pe.Violative))
b.WriteString("Prompt Results:\n")
for _, pr := range pe.PromptResults {
b.WriteString(pr.String() + "\n")
}
b.WriteString("Output Results:\n")
b.WriteString(pe.OutputResults.String() + "\n")
return b.String()
}

// A PolicyResult is the result of evaluating a policy against
// an input or output to an LLM.
type PolicyResult struct {
Text string // the text that was analyzed
Results []*llm.PolicyResult
Violations []*llm.PolicyResult
Error error
}

// String returns a human readable representation of a policy result.
func (pr *PolicyResult) String() string {
if pr == nil {
return ""
}
b := strings.Builder{}
b.WriteString(fmt.Sprintf("Text: %s\n", pr.Text))
b.WriteString(fmt.Sprintf("Results: %v\n", pr.Results))
if len(pr.Violations) > 0 {
b.WriteString(fmt.Sprintf("Violations: %v\n", pr.Violations))
}
if pr.Error != nil {
b.WriteString(fmt.Sprintf("Error: %v\n", pr.Error))
}
return b.String()
}

// HasPolicyViolation reports whether the result or its prompts
// have any policy violations.
func (r *Result) HasPolicyViolation() bool {
return r.PolicyEvaluation != nil && r.PolicyEvaluation.Violative
}
10 changes: 5 additions & 5 deletions internal/llmapp/overview.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ func (c *Client) overview(ctx context.Context, kind docsKind, groups ...*docGrou
return nil, err
}
return &Result{
Response: overview,
Cached: cached,
Schema: schema,
Prompt: prompt,
HasPolicyViolation: c.hasPolicyViolation(ctx, prompt, overview),
Response: overview,
Cached: cached,
Schema: schema,
Prompt: prompt,
PolicyEvaluation: c.evaluatePolicy(ctx, prompt, overview),
}, nil
}

Expand Down

0 comments on commit b08abec

Please sign in to comment.