Skip to content

Commit

Permalink
refact: use protoreflect to convert check request v3 (#631)
Browse files Browse the repository at this point in the history
Avoids a roundtrip through JSON that was previously done for converting
the check request. Nothing changes for v2 check requests.

Signed-off-by: Anthony Regeda <regedaster@gmail.com>
  • Loading branch information
regeda authored Jan 8, 2025
1 parent e2e9622 commit bec106e
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 11 deletions.
65 changes: 65 additions & 0 deletions envoyauth/protomap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package envoyauth

import (
"google.golang.org/protobuf/reflect/protoreflect"
)

type valueResolver func(protoreflect.Value) any

func messageResolver(v protoreflect.Value) any {
return protomap(v.Message())
}

func interfaceResolver(v protoreflect.Value) any {
return v.Interface()
}

func chooseResolver(k protoreflect.Kind) valueResolver {
if k == protoreflect.MessageKind {
return messageResolver
}
return interfaceResolver
}

// protomap converts protobuf message into map[string]any type using json names.
func protomap(msg protoreflect.Message) map[string]any {
v := msg.Interface()
// handle structpb.Struct
if mapper, ok := v.(interface{ AsMap() map[string]any }); ok {
return mapper.AsMap()
}

result := make(map[string]any, msg.Descriptor().Fields().Len())

msg.Range(func(fd protoreflect.FieldDescriptor, value protoreflect.Value) bool {
name := fd.JSONName()

switch {
case fd.IsMap():
mapValue := value.Map()
mapResult := make(map[string]any, mapValue.Len())
valResolver := chooseResolver(fd.MapValue().Kind())
mapValue.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool {
mapResult[key.String()] = valResolver(val)
return true
})
result[name] = mapResult
case fd.IsList():
list := value.List()
listResult := make([]any, list.Len())
valResolver := chooseResolver(fd.Kind())
for i := 0; i < list.Len(); i++ {
listResult[i] = valResolver(list.Get(i))
}
result[name] = listResult
case fd.Kind() == protoreflect.MessageKind:
result[name] = protomap(value.Message())
default:
result[name] = value.Interface()
}

return true
})

return result
}
130 changes: 130 additions & 0 deletions envoyauth/protomap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package envoyauth

import (
"testing"

ext_authz_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"google.golang.org/protobuf/encoding/protojson"
)

const extAuthzRequest = `{
"attributes": {
"source": {
"address": {
"socketAddress": {
"address": "127.0.0.1"
}
},
"service": "dummy",
"labels": {
"foo": "bar"
}
},
"metadataContext": {
"filterMetadata": {
"dummy": {
"hello": "world",
"count": 1
}
}
},
"contextExtensions": {
"hello": "world"
},
"request": {
"http": {
"id": "13359530607844510314",
"method": "GET",
"headers": {
":authority": "192.168.99.100:31380",
":method": "GET",
":path": "/api/v1/products",
"accept": "*/*"
},
"path": "/api/v1/products",
"host": "192.168.99.100:31380",
"protocol": "HTTP/1.1",
"body": "{\"firstname\": \"foo\", \"lastname\": \"bar\"}"
}
}
}
}`

func Test_protomap(t *testing.T) {
var req ext_authz_v3.CheckRequest

if err := protojson.Unmarshal([]byte(extAuthzRequest), &req); err != nil {
t.Fatal(err)
}

result := protomap(req.ProtoReflect())

if result == nil {
t.Fatal("not nil expected")
}

assertMap(t, result, map[string]any{
"attributes": map[string]any{
"source": map[string]any{
"service": "dummy",
"labels": map[string]any{
"foo": "bar",
},
"address": map[string]any{
"socketAddress": map[string]any{
"address": "127.0.0.1",
},
},
},
"metadataContext": map[string]any{
"filterMetadata": map[string]any{
"dummy": map[string]any{
"hello": "world",
"count": float64(1),
},
},
},
"contextExtensions": map[string]any{
"hello": "world",
},
"request": map[string]any{
"http": map[string]any{
"id": "13359530607844510314",
"method": "GET",
"path": "/api/v1/products",
"host": "192.168.99.100:31380",
"protocol": "HTTP/1.1",
"body": "{\"firstname\": \"foo\", \"lastname\": \"bar\"}",
"headers": map[string]any{
":authority": "192.168.99.100:31380",
":method": "GET",
":path": "/api/v1/products",
"accept": "*/*",
},
},
},
},
})
}

func assertMap(t *testing.T, actual map[string]any, expected map[string]any) {
t.Helper()
if len(actual) != len(expected) {
t.Fatalf("different len of maps, actual %v, expected %v", actual, expected)
}
for k, ev := range expected {
av, ok := actual[k]
if !ok {
t.Fatalf("expected key %s not found", k)
}
if em, ok := ev.(map[string]any); ok {
am, ok := av.(map[string]any)
if !ok {
t.Fatalf("both values must be map[string]any, actual %T", av)
}
assertMap(t, em, am)
} else if ev != av {
t.Fatalf("values of key %s are different, actual %v (%[2]T), expected %v (%[3]T)", k, av, ev)
}
}
}
18 changes: 7 additions & 11 deletions envoyauth/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func RequestToInput(req interface{}, logger logging.Logger, protoSet *protoregis
var err error
var input map[string]interface{}

var bs, rawBody []byte
var rawBody []byte
var path, body string
var headers, version map[string]string

Expand All @@ -41,18 +41,18 @@ func RequestToInput(req interface{}, logger logging.Logger, protoSet *protoregis
// etc -- we only care for its JSON representation as fed into evaluation later.
switch req := req.(type) {
case *ext_authz_v3.CheckRequest:
bs, err = protojson.Marshal(req)
if err != nil {
return nil, err
}
input = protomap(req.ProtoReflect())
path = req.GetAttributes().GetRequest().GetHttp().GetPath()
body = req.GetAttributes().GetRequest().GetHttp().GetBody()
headers = req.GetAttributes().GetRequest().GetHttp().GetHeaders()
rawBody = req.GetAttributes().GetRequest().GetHttp().GetRawBody()
version = v3Info
case *ext_authz_v2.CheckRequest:
bs, err = json.Marshal(req)
if err != nil {
var bs []byte
if bs, err = json.Marshal(req); err != nil {
return nil, err
}
if err = util.UnmarshalJSON(bs, &input); err != nil {
return nil, err
}
path = req.GetAttributes().GetRequest().GetHttp().GetPath()
Expand All @@ -61,10 +61,6 @@ func RequestToInput(req interface{}, logger logging.Logger, protoSet *protoregis
version = v2Info
}

err = util.UnmarshalJSON(bs, &input)
if err != nil {
return nil, err
}
input["version"] = version

parsedPath, parsedQuery, err := getParsedPathAndQuery(path)
Expand Down

0 comments on commit bec106e

Please sign in to comment.