From 0770dbb4b6a422bc93702cf04eb571a5fd0dbd01 Mon Sep 17 00:00:00 2001 From: Alex Palesandro Date: Thu, 9 Nov 2023 17:33:39 +0100 Subject: [PATCH] feat: add connector webhook logic --- cmd/dex/config.go | 4 + cmd/dex/serve.go | 1 + go.mod | 3 +- go.sum | 6 +- pkg/webhook/config/consts.go | 7 + pkg/webhook/connectors/connectors.go | 74 ++++++++++ pkg/webhook/connectors/connectors_test.go | 111 ++++++++++++++ pkg/webhook/connectors/helpers.go | 55 +++++++ pkg/webhook/connectors/helpers_test.go | 169 ++++++++++++++++++++++ pkg/webhook/connectors/types.go | 48 ++++++ server/handlers.go | 45 ++++++ server/server.go | 18 +++ 12 files changed, 538 insertions(+), 3 deletions(-) create mode 100644 pkg/webhook/config/consts.go create mode 100644 pkg/webhook/connectors/connectors.go create mode 100644 pkg/webhook/connectors/connectors_test.go create mode 100644 pkg/webhook/connectors/helpers.go create mode 100644 pkg/webhook/connectors/helpers_test.go create mode 100644 pkg/webhook/connectors/types.go diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 831156fd40..2aae9dc61e 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -11,6 +11,7 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/dexidp/dex/pkg/log" + "github.com/dexidp/dex/pkg/webhook/config" "github.com/dexidp/dex/server" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/ent" @@ -49,6 +50,9 @@ type Config struct { // querying the storage. Cannot be specified without enabling a passwords // database. StaticPasswords []password `json:"staticPasswords"` + + // ConnectorFilterHooks is a list of hooks that can be used to filter the connectors` + ConnectorFilterHooks config.ConnectorFilterHooks `json:"connectorFiltersHooks"` } // Validate the configuration diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 47b090aeab..c009a8a822 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -272,6 +272,7 @@ func runServe(options serveOptions) error { Now: now, PrometheusRegistry: prometheusRegistry, HealthChecker: healthChecker, + ConnectorFilterHooks: c.ConnectorFilterHooks, } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) diff --git a/go.mod b/go.mod index c5ebb2702e..dd16650307 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/stretchr/testify v1.8.4 go.etcd.io/etcd/client/pkg/v3 v3.5.9 go.etcd.io/etcd/client/v3 v3.5.9 + go.uber.org/mock v0.3.0 golang.org/x/crypto v0.14.0 golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741 golang.org/x/net v0.17.0 @@ -88,7 +89,7 @@ require ( go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.17.0 // indirect - golang.org/x/mod v0.10.0 // indirect + golang.org/x/mod v0.11.0 // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/appengine v1.6.8 // indirect diff --git a/go.sum b/go.sum index 5026e7c0f1..e8d62d3d97 100644 --- a/go.sum +++ b/go.sum @@ -228,6 +228,8 @@ go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= +go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/zap v1.17.0 h1:MTjgFu6ZLKvY6Pvaqk97GlxNBuMpV4Hy/3P6tRGlI2U= @@ -251,8 +253,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= -golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= +golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/pkg/webhook/config/consts.go b/pkg/webhook/config/consts.go new file mode 100644 index 0000000000..c641c55710 --- /dev/null +++ b/pkg/webhook/config/consts.go @@ -0,0 +1,7 @@ +package config + +type HookType string + +const ( + External HookType = "external" +) diff --git a/pkg/webhook/connectors/connectors.go b/pkg/webhook/connectors/connectors.go new file mode 100644 index 0000000000..e531b2186d --- /dev/null +++ b/pkg/webhook/connectors/connectors.go @@ -0,0 +1,74 @@ +//go:generate go run -mod mod go.uber.org/mock/mockgen -destination=./mocks/mock_caller.go -package=connectors --source=types.go FilterCaller +package connectors + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/dexidp/dex/pkg/webhook/config" + "github.com/dexidp/dex/pkg/webhook/helpers" + "github.com/dexidp/dex/storage" +) + +func NewConnectorFilter(hook *config.ConnectorFilterHook) (*ConnectorFilterHook, error) { + var hookInvoker FilterCaller + switch hook.Type { + case config.External: + h, err := helpers.NewWebhookHTTPHelpers(hook.Config) + if err != nil { + return nil, fmt.Errorf("could not create webhook http helpers: %w", err) + } + hookInvoker = NewFilterCaller(h, hook.RequestScope) + default: + return nil, fmt.Errorf("unknown type: %s", hook.Type) + } + return &ConnectorFilterHook{ + Name: hook.Name, + FilterInvoker: hookInvoker, + }, nil +} + +func (f WebhookFilterCaller) callHook(connectors []ConnectorContext, req RequestContext) ([]ConnectorContext, error) { + toMarshal := FilterWebhookPayload{ + Connectors: connectors, + Request: req, + } + + payload, err := json.Marshal(toMarshal) + if err != nil { + return nil, fmt.Errorf("could not serialize claims: %w", err) + } + + body, err := f.transportHelper.CallWebhook(payload) + if err != nil { + return nil, fmt.Errorf("could not call webhook: %w", err) + } + var res []ConnectorContext + + if err := json.Unmarshal(body, &res); err != nil { + return nil, fmt.Errorf("could not unmarshal response: %w", err) + } + + return res, nil +} + +func NewFilterCaller(h helpers.WebhookHTTPHelpers, + context *config.HookRequestScope, +) *WebhookFilterCaller { + return &WebhookFilterCaller{ + RequestScope: context, + transportHelper: h, + } +} + +func (f WebhookFilterCaller) CallHook(connectors []storage.Connector, + r *http.Request, +) ([]storage.Connector, error) { + payload := createConnectorWebhookPayload(f.RequestScope, connectors, r) + filteredConnectors, err := f.callHook(payload.Connectors, payload.Request) + if err != nil { + return nil, err + } + return unwrapConnectorWebhookPayload(filteredConnectors, connectors), nil +} diff --git a/pkg/webhook/connectors/connectors_test.go b/pkg/webhook/connectors/connectors_test.go new file mode 100644 index 0000000000..10e802f357 --- /dev/null +++ b/pkg/webhook/connectors/connectors_test.go @@ -0,0 +1,111 @@ +package connectors + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/dexidp/dex/pkg/webhook/config" + "github.com/dexidp/dex/pkg/webhook/helpers" + "github.com/dexidp/dex/storage" +) + +func TestNewConnectorFilter(t *testing.T) { + d, err := NewConnectorFilter(&config.ConnectorFilterHook{ + Name: "test", + Type: config.External, + RequestScope: &config.HookRequestScope{ + Headers: []string{"header1", "header2"}, + Params: []string{"param1", "param2"}, + }, + Config: &config.WebhookConfig{ + URL: "http://test.com", + InsecureSkipVerify: true, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, d) + assert.Equal(t, d.Name, "test") + assert.IsType(t, d.FilterInvoker, &WebhookFilterCaller{}) +} + +func TestNewConnectorFilter_UnknownType(t *testing.T) { + d, err := NewConnectorFilter(&config.ConnectorFilterHook{ + Name: "test", + Type: "unknown", + RequestScope: &config.HookRequestScope{ + Headers: []string{"header1", "header2"}, + Params: []string{"param1", "param2"}, + }, + Config: &config.WebhookConfig{ + URL: "http://test.com", + InsecureSkipVerify: true, + }, + }) + assert.Error(t, err) + assert.Nil(t, d) +} + +func TestNewFilterCaller(t *testing.T) { + h, err := helpers.NewWebhookHTTPHelpers(&config.WebhookConfig{ + URL: "http://test.com", + InsecureSkipVerify: true, + }) + assert.NoError(t, err) + d := NewFilterCaller(h, &config.HookRequestScope{ + Headers: []string{"header1", "header2"}, + Params: []string{"param1", "param2"}, + }) + assert.NotNil(t, d) + assert.Equal(t, h, d.transportHelper) + assert.Equal(t, d.RequestScope.Headers, []string{"header1", "header2"}) + assert.Equal(t, d.RequestScope.Params, []string{"param1", "param2"}) +} + +func TestCallHook_Logic(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + h := helpers.NewMockWebhookHTTPHelpers(ctrl) + h.EXPECT().CallWebhook(gomock.Any()).Return([]byte(`[{"id": "test", "type": "test", "name": "test"}]`), nil) + d := NewFilterCaller(h, &config.HookRequestScope{ + Headers: []string{"header1", "header2"}, + Params: []string{"param1", "param2"}, + }) + connectorList, err := d.CallHook([]storage.Connector{ + { + ID: "test", + Type: "test", + Name: "test", + }, + }, &http.Request{}) + assert.NoError(t, err) + assert.Equal(t, connectorList, []storage.Connector{ + { + ID: "test", + Type: "test", + Name: "test", + }, + }) +} + +func TestCallHook_Logic_Error(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + h := helpers.NewMockWebhookHTTPHelpers(ctrl) + h.EXPECT().CallWebhook(gomock.Any()).Return(nil, assert.AnError) + d := NewFilterCaller(h, &config.HookRequestScope{ + Headers: []string{"header1", "header2"}, + Params: []string{"param1", "param2"}, + }) + connectorList, err := d.CallHook([]storage.Connector{ + { + ID: "test", + Type: "test", + Name: "test", + }, + }, &http.Request{}) + assert.Error(t, err) + assert.Nil(t, connectorList) +} diff --git a/pkg/webhook/connectors/helpers.go b/pkg/webhook/connectors/helpers.go new file mode 100644 index 0000000000..277c0215f4 --- /dev/null +++ b/pkg/webhook/connectors/helpers.go @@ -0,0 +1,55 @@ +package connectors + +import ( + "net/http" + + "golang.org/x/exp/slices" + + "github.com/dexidp/dex/pkg/webhook/config" + "github.com/dexidp/dex/storage" +) + +func createConnectorWebhookPayload(requestScope *config.HookRequestScope, connectors []storage.Connector, + r *http.Request, +) *FilterWebhookPayload { + payload := &FilterWebhookPayload{ + Connectors: []ConnectorContext{}, + Request: RequestContext{}, + } + for _, c := range connectors { + payload.Connectors = append(payload.Connectors, ConnectorContext{ + ID: c.ID, + Type: c.Type, + Name: c.Name, + }) + } + payload.Request.Params = make(map[string][]string) + if r != nil && r.URL != nil { + for k, v := range r.URL.Query() { + if slices.Contains(requestScope.Params, k) { + payload.Request.Params[k] = v + } + } + } + payload.Request.Headers = make(map[string][]string) + for k, v := range r.Header { + if slices.Contains(requestScope.Headers, k) { + payload.Request.Headers[k] = v + } + } + return payload +} + +func unwrapConnectorWebhookPayload(filteredConnectors []ConnectorContext, + initialConnectors []storage.Connector, +) []storage.Connector { + mappedConnectors := make([]storage.Connector, 0) + for _, filteredConnector := range filteredConnectors { + for _, initialConnector := range initialConnectors { + if filteredConnector.ID == initialConnector.ID { + mappedConnectors = append(mappedConnectors, initialConnector) + } + } + } + return mappedConnectors +} diff --git a/pkg/webhook/connectors/helpers_test.go b/pkg/webhook/connectors/helpers_test.go new file mode 100644 index 0000000000..3899ecc73f --- /dev/null +++ b/pkg/webhook/connectors/helpers_test.go @@ -0,0 +1,169 @@ +package connectors + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/dexidp/dex/pkg/webhook/config" + "github.com/dexidp/dex/storage" +) + +func Test_CreatePayloadEmpty(t *testing.T) { + payload := createConnectorWebhookPayload(&config.HookRequestScope{}, []storage.Connector{}, &http.Request{}) + assert.Equal(t, payload, &FilterWebhookPayload{ + Connectors: []ConnectorContext{}, + Request: RequestContext{ + Headers: map[string][]string{}, + Params: map[string][]string{}, + }, + }) +} + +func Test_CreatePayload_ScopeValidation(t *testing.T) { + payload := createConnectorWebhookPayload(&config.HookRequestScope{ + Headers: []string{"header1", "header2"}, + Params: []string{"param1", "param2"}, + }, []storage.Connector{}, &http.Request{ + Header: map[string][]string{ + "header1": {"value1"}, + "header2": {"value2"}, + "header3": {"value3"}, + }, + URL: &url.URL{ + RawQuery: "param1=value1¶m2=value2¶m3=value3", + }, + }) + assert.Equal(t, payload, &FilterWebhookPayload{ + Connectors: []ConnectorContext{}, + Request: RequestContext{ + Headers: map[string][]string{ + "header1": {"value1"}, + "header2": {"value2"}, + }, + Params: map[string][]string{ + "param1": {"value1"}, + "param2": {"value2"}, + }, + }, + }) +} + +func Test_CreatePayload_ScopeConnectorsValidation(t *testing.T) { + payload := createConnectorWebhookPayload(&config.HookRequestScope{}, []storage.Connector{ + {ID: "test1", Type: "ldap", Name: "test1", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + }, &http.Request{}) + assert.Equal(t, payload, &FilterWebhookPayload{ + Request: RequestContext{ + Headers: map[string][]string{}, + Params: map[string][]string{}, + }, + Connectors: []ConnectorContext{ + {ID: "test1", Type: "ldap", Name: "test1"}, + }, + }) +} + +func Test_MultipleConnectorValidation(t *testing.T) { + payload := createConnectorWebhookPayload(&config.HookRequestScope{}, []storage.Connector{ + {ID: "test1", Type: "ldap", Name: "test1", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test2", Type: "ldap", Name: "test2", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + }, &http.Request{}) + assert.Equal(t, payload, &FilterWebhookPayload{ + Request: RequestContext{ + Headers: map[string][]string{}, + Params: map[string][]string{}, + }, + Connectors: []ConnectorContext{ + {ID: "test1", Type: "ldap", Name: "test1"}, + {ID: "test2", Type: "ldap", Name: "test2"}, + }, + }) +} + +func Test_UnwrapConnectorWebhookPayload_Empty(t *testing.T) { + filteredConnectors := []ConnectorContext{} + initialConnectors := []storage.Connector{} + mappedConnectors := unwrapConnectorWebhookPayload(filteredConnectors, initialConnectors) + assert.Equal(t, mappedConnectors, []storage.Connector{}) +} + +func Test_UnwrapConnectorWebhookPayload_LocalLDAP(t *testing.T) { + filteredConnectors := []ConnectorContext{ + {ID: "test2", Type: "ldap", Name: "test2"}, + {ID: "test3", Type: "ldap", Name: "test3"}, + } + initialConnectors := []storage.Connector{ + {ID: "test1", Type: "ldap", Name: "test1", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test2", Type: "ldap", Name: "test2", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test3", Type: "ldap", Name: "test3", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + } + mappedConnectors := unwrapConnectorWebhookPayload(filteredConnectors, initialConnectors) + assert.Equal(t, mappedConnectors, []storage.Connector{ + {ID: "test2", Type: "ldap", Name: "test2", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test3", Type: "ldap", Name: "test3", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + }) +} + +func Test_UnwrapConnectorWebhookPayload_DifferentOrder(t *testing.T) { + filteredConnectors := []ConnectorContext{ + {ID: "test3", Type: "ldap", Name: "test3"}, + {ID: "test2", Type: "ldap", Name: "test2"}, + } + initialConnectors := []storage.Connector{ + {ID: "test1", Type: "ldap", Name: "test1", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test2", Type: "ldap", Name: "test2", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test3", Type: "ldap", Name: "test3", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + } + mappedConnectors := unwrapConnectorWebhookPayload(filteredConnectors, initialConnectors) + assert.Equal(t, mappedConnectors, []storage.Connector{ + {ID: "test3", Type: "ldap", Name: "test3", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test2", Type: "ldap", Name: "test2", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + }) +} + +func Test_UnwrapConnectorWebhookPayload_LocalOIDC(t *testing.T) { + filteredConnectors := []ConnectorContext{ + {ID: "test3", Type: "oidc", Name: "test3"}, + } + initialConnectors := []storage.Connector{ + {ID: "test1", Type: "oidc", Name: "test1", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test2", Type: "oidc", Name: "test2", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test3", Type: "oidc", Name: "test3", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + } + mappedConnectors := unwrapConnectorWebhookPayload(filteredConnectors, initialConnectors) + assert.Equal(t, mappedConnectors, []storage.Connector{ + {ID: "test3", Type: "oidc", Name: "test3", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + }) +} + +func Test_UnwrapConnectorWebhookPayload_NotExisting(t *testing.T) { + filteredConnectors := []ConnectorContext{ + {ID: "test4", Type: "oidc", Name: "test4"}, + } + initialConnectors := []storage.Connector{ + {ID: "test1", Type: "oidc", Name: "test1", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test2", Type: "oidc", Name: "test2", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test3", Type: "oidc", Name: "test3", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + } + mappedConnectors := unwrapConnectorWebhookPayload(filteredConnectors, initialConnectors) + assert.Equal(t, mappedConnectors, []storage.Connector{}) +} + +func Test_UnwrapConnectorWebhookPayload_NotExistingTest(t *testing.T) { + filteredConnectors := []ConnectorContext{ + {ID: "test4", Type: "oidc", Name: "test4"}, + {ID: "test3", Type: "oidc", Name: "test3"}, + } + initialConnectors := []storage.Connector{ + {ID: "test1", Type: "oidc", Name: "test1", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test2", Type: "oidc", Name: "test2", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + {ID: "test3", Type: "oidc", Name: "test3", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + } + mappedConnectors := unwrapConnectorWebhookPayload(filteredConnectors, initialConnectors) + assert.Equal(t, mappedConnectors, []storage.Connector{ + {ID: "test3", Type: "oidc", Name: "test3", ResourceVersion: "123", Config: []byte(`{"some":"data"}`)}, + }) +} diff --git a/pkg/webhook/connectors/types.go b/pkg/webhook/connectors/types.go new file mode 100644 index 0000000000..147c2c120b --- /dev/null +++ b/pkg/webhook/connectors/types.go @@ -0,0 +1,48 @@ +package connectors + +import ( + "net/http" + + "github.com/dexidp/dex/pkg/webhook/config" + "github.com/dexidp/dex/pkg/webhook/helpers" + "github.com/dexidp/dex/storage" +) + +type FilterCaller interface { + CallHook(connectors []storage.Connector, r *http.Request) ([]storage.Connector, error) +} + +type ConnectorFilterHook struct { + // Name is the name of the webhook + Name string `json:"name"` + // Config is the configuration of the webhook + FilterInvoker FilterCaller `json:"filterInvoker"` +} + +var _ FilterCaller = &WebhookFilterCaller{} + +type WebhookFilterCaller struct { + RequestScope *config.HookRequestScope + transportHelper helpers.WebhookHTTPHelpers +} + +type ConnectorContext struct { + // ID that will uniquely identify the connector object. + ID string `json:"id"` + // The Type of the connector. E.g. 'oidc' or 'ldap' + Type string `json:"type"` + // The Name of the connector that is used when displaying it to the end user. + Name string `json:"name"` +} + +type RequestContext struct { + // Headers is the headers of the request + Headers map[string][]string `json:"headers"` + // Params is the params of the request + Params map[string][]string `json:"params"` +} + +type FilterWebhookPayload struct { + Connectors []ConnectorContext `json:"connID"` + Request RequestContext `json:"requestContext"` +} diff --git a/server/handlers.go b/server/handlers.go index 08a60d48da..2cde607b7c 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -144,6 +144,24 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { return } + FilteredConnectors := []storage.Connector{} + FilteredConnectors = append(FilteredConnectors, connectors...) + // Webhook to filter out providers + + s.logger.Infof("Connectors: %d", len(FilteredConnectors)) + s.logger.Infof("Invoking %d hooks to filter connectors", len(s.connectorWebhookFilter)) + for _, c := range s.connectorWebhookFilter { + s.logger.Infof("Calling connectors webhook %s", c.Name) + FilteredConnectors, err = c.FilterInvoker.CallHook(FilteredConnectors, r) + if err != nil { + s.logger.Errorf("Failed to filter connectors: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.") + return + } + s.logger.Infof("Connectors after webhook %s: %d", c.Name, len(FilteredConnectors)) + } + connectors = FilteredConnectors + // We don't need connector_id any more r.Form.Del("connector_id") @@ -217,6 +235,33 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } + connObj, err := s.storage.GetConnector(connID) + if err != nil { + s.logger.Errorf("Failed to get connector: %v", err) + s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") + return + } + + s.logger.Infof("Connector Validation: %s", connID) + s.logger.Infof("Invoking %d hooks to filter connectors", len(s.connectorWebhookFilter)) + filteredConnectors := []storage.Connector{connObj} + for _, c := range s.connectorWebhookFilter { + s.logger.Infof("Calling connectors webhook %s", c.Name) + filteredConnectors, err = c.FilterInvoker.CallHook(filteredConnectors, r) + if err != nil { + s.logger.Errorf("Failed to filter connectors: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.") + return + } + s.logger.Infof("Connectors after webhook %s: %d", c.Name, len(filteredConnectors)) + } + + if len(filteredConnectors) == 0 { + s.logger.Errorf("Connector %s not valid for the current request", connID) + s.renderError(r, w, http.StatusBadRequest, "Invalid Request") + return + } + // Set the connector being used for the login. if authReq.ConnectorID != "" && authReq.ConnectorID != connID { s.logger.Errorf("Mismatched connector ID in auth request: %s vs %s", diff --git a/server/server.go b/server/server.go index bf83dd81f0..41d4238a0f 100644 --- a/server/server.go +++ b/server/server.go @@ -43,6 +43,8 @@ import ( "github.com/dexidp/dex/connector/openshift" "github.com/dexidp/dex/connector/saml" "github.com/dexidp/dex/pkg/log" + "github.com/dexidp/dex/pkg/webhook/config" + "github.com/dexidp/dex/pkg/webhook/connectors" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/web" ) @@ -81,6 +83,10 @@ type Config struct { // Logging in implies approval. SkipApprovalScreen bool + // ConnectorFilterHooks is a list of hooks that can be used to filter the connectors returned by dex in the + // login page. + ConnectorFilterHooks config.ConnectorFilterHooks + // If enabled, the connectors selection page will always be shown even if there's only one AlwaysShowLoginScreen bool @@ -184,6 +190,8 @@ type Server struct { refreshTokenPolicy *RefreshTokenPolicy logger log.Logger + + connectorWebhookFilter []*connectors.ConnectorFilterHook } // NewServer constructs a server from the provided config. @@ -281,6 +289,15 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) now = time.Now } + connectorFilters := make([]*connectors.ConnectorFilterHook, 0) + for _, hook := range c.ConnectorFilterHooks.FilterHooks { + filter, err := connectors.NewConnectorFilter(hook) + if err != nil { + return nil, err + } + connectorFilters = append(connectorFilters, filter) + } + s := &Server{ issuerURL: *issuerURL, connectors: make(map[string]Connector), @@ -297,6 +314,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) templates: tmpls, passwordConnector: c.PasswordConnector, logger: c.Logger, + connectorWebhookFilter: connectorFilters, } // Retrieves connector objects in backend storage. This list includes the static connectors