Skip to content

Commit

Permalink
Dynamically create jwks clusters for jwt-providers
Browse files Browse the repository at this point in the history
  • Loading branch information
roncodingenthusiast committed Jun 29, 2023
1 parent 85b78fe commit f3adf49
Show file tree
Hide file tree
Showing 16 changed files with 520 additions and 21 deletions.
94 changes: 94 additions & 0 deletions agent/xds/clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package xds
import (
"errors"
"fmt"
"net/url"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -141,6 +143,22 @@ func (s *ResourceGenerator) clustersFromSnapshotConnectProxy(cfgSnap *proxycfg.C
clusters = append(clusters, upstreamCluster)
}

// add clusters for jwt-providers
for _, prov := range cfgSnap.JWTProviders {
//skip cluster creation for local providers
if prov.JSONWebKeySet == nil || prov.JSONWebKeySet.Remote == nil {
continue
}

cluster, err := makeJWTProviderCluster(prov)
if err != nil {
s.Logger.Warn("failed to make jwt-provider cluster", "provider name", prov.Name, "error", err)
continue
}

clusters = append(clusters, cluster)
}

for _, u := range cfgSnap.Proxy.Upstreams {
if u.DestinationType != structs.UpstreamDestTypePreparedQuery {
continue
Expand Down Expand Up @@ -184,6 +202,82 @@ func (s *ResourceGenerator) clustersFromSnapshotConnectProxy(cfgSnap *proxycfg.C
return clusters, nil
}

func makeJWTProviderCluster(p *structs.JWTProviderConfigEntry) (*envoy_cluster_v3.Cluster, error) {
if p.JSONWebKeySet == nil || p.JSONWebKeySet.Remote == nil {
return nil, fmt.Errorf("cannot create JWKS cluster for non-remote JWKS. Provider Name: %s", p.Name)
}
hostname, scheme, port, err := parseJWTRemoteURL(p.JSONWebKeySet.Remote.URI)
if err != nil {
return nil, err
}

// TODO: expose additional fields: eg. ConnectTimeout, through
// JWTProviderConfigEntry to allow user to configure cluster
cluster := &envoy_cluster_v3.Cluster{
Name: makeJWKSClusterName(p.Name),
ClusterDiscoveryType: &envoy_cluster_v3.Cluster_Type{
Type: envoy_cluster_v3.Cluster_STRICT_DNS,
},
LoadAssignment: &envoy_endpoint_v3.ClusterLoadAssignment{
ClusterName: makeJWKSClusterName(p.Name),
Endpoints: []*envoy_endpoint_v3.LocalityLbEndpoints{
{
LbEndpoints: []*envoy_endpoint_v3.LbEndpoint{
makeEndpoint(hostname, port),
},
},
},
},
}

if scheme == "https" {
// TODO: expose this configuration through JWTProviderConfigEntry to allow
// user to configure certs
jwksTLSContext, err := makeUpstreamTLSTransportSocket(
&envoy_tls_v3.UpstreamTlsContext{
CommonTlsContext: &envoy_tls_v3.CommonTlsContext{
ValidationContextType: &envoy_tls_v3.CommonTlsContext_ValidationContext{
ValidationContext: &envoy_tls_v3.CertificateValidationContext{},
},
},
},
)
if err != nil {
return nil, err
}

cluster.TransportSocket = jwksTLSContext
}
return cluster, nil
}

// parseJWTRemoteURL splits the URI into domain, scheme and port.
// It will default to port 80 for http and 443 for https for any
// URI that does not specify a port.
func parseJWTRemoteURL(uri string) (string, string, int, error) {
u, err := url.ParseRequestURI(uri)
if err != nil {
return "", "", 0, err
}

var port int
if u.Port() != "" {
port, err = strconv.Atoi(u.Port())
if err != nil {
return "", "", port, err
}
}

if port == 0 {
port = 80
if u.Scheme == "https" {
port = 443
}
}

return u.Hostname(), u.Scheme, port, nil
}

func makeExposeClusterName(destinationPort int) string {
return fmt.Sprintf("exposed_cluster_%d", destinationPort)
}
Expand Down
179 changes: 179 additions & 0 deletions agent/xds/clusters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,185 @@ func TestEnvoyLBConfig_InjectToCluster(t *testing.T) {
}
}

func TestMakeJWTProviderCluster(t *testing.T) {
// All tests here depend on golden files located under: agent/xds/testdata/jwt_authn_cluster/*
tests := map[string]struct {
provider *structs.JWTProviderConfigEntry
expectedError string
}{
"remote-jwks-not-configured": {
provider: &structs.JWTProviderConfigEntry{
Kind: "jwt-provider",
Name: "okta",
JSONWebKeySet: &structs.JSONWebKeySet{},
},
expectedError: "cannot create JWKS cluster for non remote JWKS. Provider Name: okta",
},
"local-jwks-configured": {
provider: &structs.JWTProviderConfigEntry{
Kind: "jwt-provider",
Name: "okta",
JSONWebKeySet: &structs.JSONWebKeySet{
Local: &structs.LocalJWKS{
Filename: "filename",
},
},
},
expectedError: "cannot create JWKS cluster for non remote JWKS. Provider Name: okta",
},
"https-provider-with-hostname-no-port": {
provider: makeTestProviderWithJWKS("https://example-okta.com/.well-known/jwks.json"),
},
"http-provider-with-hostname-no-port": {
provider: makeTestProviderWithJWKS("http://example-okta.com/.well-known/jwks.json"),
},
"https-provider-with-hostname-and-port": {
provider: makeTestProviderWithJWKS("https://example-okta.com:90/.well-known/jwks.json"),
},
"http-provider-with-hostname-and-port": {
provider: makeTestProviderWithJWKS("http://example-okta.com:90/.well-known/jwks.json"),
},
"https-provider-with-ip-no-port": {
provider: makeTestProviderWithJWKS("https://127.0.0.1"),
},
"http-provider-with-ip-no-port": {
provider: makeTestProviderWithJWKS("http://127.0.0.1"),
},
"https-provider-with-ip-and-port": {
provider: makeTestProviderWithJWKS("https://127.0.0.1:9091"),
},
"http-provider-with-ip-and-port": {
provider: makeTestProviderWithJWKS("http://127.0.0.1:9091"),
},
}

for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
cluster, err := makeJWTProviderCluster(tt.provider)
if tt.expectedError != "" {
require.Error(t, err, tt.expectedError)
} else {
require.NoError(t, err)
gotJSON := protoToJSON(t, cluster)
require.JSONEq(t, goldenSimple(t, filepath.Join("jwt_authn_clusters", name), gotJSON), gotJSON)
}

})
}
}

func makeTestProviderWithJWKS(uri string) *structs.JWTProviderConfigEntry {
return &structs.JWTProviderConfigEntry{
Kind: "jwt-provider",
Name: "okta",
Issuer: "test-issuer",
JSONWebKeySet: &structs.JSONWebKeySet{
Remote: &structs.RemoteJWKS{
RequestTimeoutMs: 1000,
FetchAsynchronously: true,
URI: uri,
},
},
}
}

func TestParseJWTRemoteURL(t *testing.T) {
tests := map[string]struct {
uri string
expectedHost string
expectedPort int
expectedScheme string
expectError bool
}{
"invalid-url": {
uri: ".com",
expectError: true,
},
"https-hostname-no-port": {
uri: "https://test.test.com",
expectedHost: "test.test.com",
expectedPort: 443,
expectedScheme: "https",
},
"https-hostname-with-port": {
uri: "https://test.test.com:4545",
expectedHost: "test.test.com",
expectedPort: 4545,
expectedScheme: "https",
},
"https-hostname-with-port-and-path": {
uri: "https://test.test.com:4545/test",
expectedHost: "test.test.com",
expectedPort: 4545,
expectedScheme: "https",
},
"http-hostname-no-port": {
uri: "http://test.test.com",
expectedHost: "test.test.com",
expectedPort: 80,
expectedScheme: "http",
},
"http-hostname-with-port": {
uri: "http://test.test.com:4636",
expectedHost: "test.test.com",
expectedPort: 4636,
expectedScheme: "http",
},
"https-ip-no-port": {
uri: "https://127.0.0.1",
expectedHost: "127.0.0.1",
expectedPort: 443,
expectedScheme: "https",
},
"https-ip-with-port": {
uri: "https://127.0.0.1:3434",
expectedHost: "127.0.0.1",
expectedPort: 3434,
expectedScheme: "https",
},
"http-ip-no-port": {
uri: "http://127.0.0.1",
expectedHost: "127.0.0.1",
expectedPort: 80,
expectedScheme: "http",
},
"http-ip-with-port": {
uri: "http://127.0.0.1:9190",
expectedHost: "127.0.0.1",
expectedPort: 9190,
expectedScheme: "http",
},
"http-ip-with-port-and-path": {
uri: "http://127.0.0.1:9190/some/where",
expectedHost: "127.0.0.1",
expectedPort: 9190,
expectedScheme: "http",
},
"http-ip-no-port-with-path": {
uri: "http://127.0.0.1/test/path",
expectedHost: "127.0.0.1",
expectedPort: 80,
expectedScheme: "http",
},
}

for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
host, scheme, port, err := parseJWTRemoteURL(tt.uri)
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, host, tt.expectedHost)
require.Equal(t, scheme, tt.expectedScheme)
require.Equal(t, port, tt.expectedPort)
}
})
}
}

// UID is just a convenience function to aid in writing tests less verbosely.
func UID(input string) proxycfg.UpstreamID {
return proxycfg.UpstreamIDFromString(input)
Expand Down
15 changes: 9 additions & 6 deletions agent/xds/jwt_authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
const (
jwtEnvoyFilter = "envoy.filters.http.jwt_authn"
jwtMetadataKeyPrefix = "jwt_payload"
jwksClusterPrefix = "jwks_cluster"
)

// This is an intermediate JWTProvider form used to associate
Expand Down Expand Up @@ -158,7 +159,7 @@ func buildJWTProviderConfig(p *structs.JWTProviderConfigEntry, metadataKeySuffix
}
envoyCfg.JwksSourceSpecifier = specifier
} else if remote := p.JSONWebKeySet.Remote; remote != nil && remote.URI != "" {
envoyCfg.JwksSourceSpecifier = makeRemoteJWKS(remote)
envoyCfg.JwksSourceSpecifier = makeRemoteJWKS(remote, p.Name)
} else {
return nil, fmt.Errorf("invalid jwt provider config; missing JSONWebKeySet for provider: %s", p.Name)
}
Expand Down Expand Up @@ -210,14 +211,12 @@ func makeLocalJWKS(l *structs.LocalJWKS, pName string) (*envoy_http_jwt_authn_v3
return specifier, nil
}

func makeRemoteJWKS(r *structs.RemoteJWKS) *envoy_http_jwt_authn_v3.JwtProvider_RemoteJwks {
func makeRemoteJWKS(r *structs.RemoteJWKS, providerName string) *envoy_http_jwt_authn_v3.JwtProvider_RemoteJwks {
remote_specifier := envoy_http_jwt_authn_v3.JwtProvider_RemoteJwks{
RemoteJwks: &envoy_http_jwt_authn_v3.RemoteJwks{
HttpUri: &envoy_core_v3.HttpUri{
Uri: r.URI,
// TODO(roncodingenthusiast): An explicit cluster is required.
// Need to figure out replacing `jwks_cluster` will an actual cluster
HttpUpstreamType: &envoy_core_v3.HttpUri_Cluster{Cluster: "jwks_cluster"},
Uri: r.URI,
HttpUpstreamType: &envoy_core_v3.HttpUri_Cluster{Cluster: makeJWKSClusterName(providerName)},
},
AsyncFetch: &envoy_http_jwt_authn_v3.JwksAsyncFetch{
FastListener: r.FetchAsynchronously,
Expand All @@ -239,6 +238,10 @@ func makeRemoteJWKS(r *structs.RemoteJWKS) *envoy_http_jwt_authn_v3.JwtProvider_
return &remote_specifier
}

func makeJWKSClusterName(providerName string) string {
return fmt.Sprintf("%s_%s", jwksClusterPrefix, providerName)
}

func buildJWTRetryPolicy(r *structs.JWKSRetryPolicy) *envoy_core_v3.RetryPolicy {
var pol envoy_core_v3.RetryPolicy
if r == nil {
Expand Down
Loading

0 comments on commit f3adf49

Please sign in to comment.