From c9c2cbc53c23c361d7aa69a9e110be1ffa22f972 Mon Sep 17 00:00:00 2001 From: Ross Kinder Date: Sun, 7 Jan 2018 20:19:58 -0500 Subject: [PATCH] samlsp: remove X-Saml headers in favor of attaching Claims to request context (#131) --- README.md | 2 +- example/trivial/trivial.go | 2 +- saml.go | 3 +- samlsp/middleware.go | 79 ++++++++++++++------------------------ samlsp/middleware_test.go | 65 ++++++------------------------- samlsp/token.go | 47 +++++++++++++++++++++++ 6 files changed, 92 insertions(+), 106 deletions(-) create mode 100644 samlsp/token.go diff --git a/README.md b/README.md index 68a55718..1b726360 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ import ( ) func hello(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello, %s!", r.Header.Get("X-Saml-Cn")) + fmt.Fprintf(w, "Hello, %s!", samlsp.Token(r.Context()).Attributes.Get("cn")) } func main() { diff --git a/example/trivial/trivial.go b/example/trivial/trivial.go index b120c862..4cc3c11a 100644 --- a/example/trivial/trivial.go +++ b/example/trivial/trivial.go @@ -14,7 +14,7 @@ import ( ) func hello(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello, %s!", r.Header.Get("X-Saml-Cn")) + fmt.Fprintf(w, "Hello, %s!", samlsp.Token(r.Context()).Attributes.Get("cn")) } func main() { diff --git a/saml.go b/saml.go index cb2b2def..afdcea7e 100644 --- a/saml.go +++ b/saml.go @@ -68,7 +68,8 @@ // ) // // func hello(w http.ResponseWriter, r *http.Request) { -// fmt.Fprintf(w, "Hello, %s!", r.Header.Get("X-Saml-Cn")) +// claims := samlsp.Claims(r.Context()) +// fmt.Fprintf(w, "Hello, %s!", claims.Attributes["cn"][0]) // } // // func main() { diff --git a/samlsp/middleware.go b/samlsp/middleware.go index 3e61ea44..acb44c3c 100644 --- a/samlsp/middleware.go +++ b/samlsp/middleware.go @@ -19,7 +19,7 @@ import ( // It implements http.Handler so that it can provide the metadata and ACS endpoints, // typically /saml/metadata and /saml/acs, respectively. // -// It also provides middleware, RequireAccount which redirects users to +// It also provides middleware RequireAccount which redirects users to // the auth process if they do not have session credentials. // // When redirecting the user through the SAML auth flow, the middlware assigns @@ -37,12 +37,9 @@ import ( // authenticated attributes from the SAML assertion. // // When the middlware receives a request with a valid session JWT it extracts -// the SAML attributes and modifies the http.Request object adding headers -// corresponding to the specified attributes. For example, if the attribute -// "cn" were present in the initial assertion with a value of "Alice Smith", -// then a corresponding header "X-Saml-Cn" will be added to the request with -// a value of "Alice Smith". For safety, the middleware strips out any existing -// headers that begin with "X-Saml-". +// the SAML attributes and modifies the http.Request object adding a Context +// object to the request context that contains attributes from the initial +// SAML assertion. // // When issuing JSON Web Tokens, a signing key is required. Because the // SAML service provider already has a private key, we borrow that key @@ -105,7 +102,8 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { // to start the SAML auth flow. func (m *Middleware) RequireAccount(handler http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - if m.IsAuthorized(r) { + if token := m.GetAuthorizationToken(r); token != nil { + r = r.WithContext(WithToken(r.Context(), token)) handler.ServeHTTP(w, r) return } @@ -209,11 +207,6 @@ func (m *Middleware) getPossibleRequestIDs(r *http.Request) []string { return rv } -type TokenClaims struct { - jwt.StandardClaims - Attributes map[string][]string `json:"attr"` -} - // Authorize is invoked by ServeHTTP when we have a new, valid SAML assertion. // It sets a cookie that contains a signed JWT containing the assertion attributes. // It then redirects the user's browser to the original URL contained in RelayState. @@ -250,7 +243,7 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion } now := saml.TimeNow() - claims := TokenClaims{} + claims := AuthorizationToken{} claims.Audience = m.ServiceProvider.Metadata().EntityID claims.IssuedAt = now.Unix() claims.ExpiresAt = now.Add(m.CookieMaxAge).Unix() @@ -272,6 +265,7 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion } } } + signedToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secretBlock) if err != nil { @@ -291,64 +285,49 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion http.Redirect(w, r, redirectURI, http.StatusFound) } -// IsAuthorized is invoked by RequireAccount to determine if the request -// is already authorized or if the user's browser should be redirected to the -// SAML login flow. If the request is authorized, then the request headers -// starting with X-Saml- for each SAML assertion attribute are set. For example, -// if an attribute "uid" has the value "alice@example.com", then the following -// header would be added to the request: +// IsAuthorized returns true if the request has already been authorized. // -// X-Saml-Uid: alice@example.com -// -// It is an error for this function to be invoked with a request containing -// any headers starting with X-Saml. This function will panic if you do. +// Note: This function is retained for compatability. Use GetAuthorizationToken in new code +// instead. func (m *Middleware) IsAuthorized(r *http.Request) bool { + return m.GetAuthorizationToken(r) != nil +} + +// GetAuthorizationToken is invoked by RequireAccount to determine if the request +// is already authorized or if the user's browser should be redirected to the +// SAML login flow. If the request is authorized, then the request context is +// ammended with a Context object. +func (m *Middleware) GetAuthorizationToken(r *http.Request) *AuthorizationToken { cookie, err := r.Cookie(m.CookieName) if err != nil { - return false + return nil } - tokenClaims := TokenClaims{} + tokenClaims := AuthorizationToken{} token, err := jwt.ParseWithClaims(cookie.Value, &tokenClaims, func(t *jwt.Token) (interface{}, error) { secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key) return secretBlock, nil }) if err != nil || !token.Valid { m.ServiceProvider.Logger.Printf("ERROR: invalid token: %s", err) - return false + return nil } if err := tokenClaims.StandardClaims.Valid(); err != nil { m.ServiceProvider.Logger.Printf("ERROR: invalid token claims: %s", err) - return false + return nil } if tokenClaims.Audience != m.ServiceProvider.Metadata().EntityID { m.ServiceProvider.Logger.Printf("ERROR: invalid audience: %s", err) - return false - } - - // It is an error for the request to include any X-SAML* headers, - // because those might be confused with ours. If we encounter any - // such headers, we abort the request, so there is no confustion. - for headerName := range r.Header { - if strings.HasPrefix(headerName, "X-Saml") { - panic("X-Saml-* headers should not exist when this function is called") - } - } - - for claimName, claimValues := range tokenClaims.Attributes { - for _, claimValue := range claimValues { - r.Header.Add("X-Saml-"+claimName, claimValue) - } + return nil } - r.Header.Set("X-Saml-Subject", tokenClaims.Subject) - return true + return &tokenClaims } // RequireAttribute returns a middleware function that requires that the // SAML attribute `name` be set to `value`. This can be used to require -// that a remote user be a member of a group. It relies on the X-Saml-* headers -// that RequireAccount adds to the request. +// that a remote user be a member of a group. It relies on the Claims assigned +// to to the context in RequireAccount. // // For example: // @@ -358,8 +337,8 @@ func (m *Middleware) IsAuthorized(r *http.Request) bool { func RequireAttribute(name, value string) func(http.Handler) http.Handler { return func(handler http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - if values, ok := r.Header[http.CanonicalHeaderKey(fmt.Sprintf("X-Saml-%s", name))]; ok { - for _, actualValue := range values { + if claims := Token(r.Context()); claims != nil { + for _, actualValue := range claims.Attributes[name] { if actualValue == value { handler.ServeHTTP(w, r) return diff --git a/samlsp/middleware_test.go b/samlsp/middleware_test.go index fe46e861..4c4fbd07 100644 --- a/samlsp/middleware_test.go +++ b/samlsp/middleware_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/rsa" "crypto/sha256" + "crypto/x509" "encoding/base64" "encoding/xml" "io/ioutil" @@ -17,8 +18,6 @@ import ( dsig "github.com/russellhaering/goxmldsig" . "gopkg.in/check.v1" - "crypto/x509" - "github.com/crewjam/saml" "github.com/crewjam/saml/logger" "github.com/crewjam/saml/testsaml" @@ -218,16 +217,17 @@ func (test *MiddlewareTest) TestRequireAccountNoCredsPostBinding(c *C) { func (test *MiddlewareTest) TestRequireAccountCreds(c *C) { handler := test.Middleware.RequireAccount( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c.Assert(r.Header.Get("X-Saml-Telephonenumber"), Equals, "555-5555") - c.Assert(r.Header["X-Saml-Edupersonscopedaffiliation"], DeepEquals, []string{"Member@testshib.org", "Staff@testshib.org"}) - c.Assert(r.Header.Get("X-Saml-Sn"), Equals, "And I") - c.Assert(r.Header.Get("X-Saml-Edupersonentitlement"), Equals, "urn:mace:dir:entitlement:common-lib-terms") - c.Assert(r.Header.Get("X-Saml-Edupersontargetedid"), Equals, "") - c.Assert(r.Header.Get("X-Saml-Givenname"), Equals, "Me Myself") - c.Assert(r.Header.Get("X-Saml-Cn"), Equals, "Me Myself And I") - c.Assert(r.Header["X-Saml-Edupersonaffiliation"], DeepEquals, []string{"Member", "Staff"}) - c.Assert(r.Header.Get("X-Saml-Uid"), Equals, "myself") - c.Assert(r.Header.Get("X-Saml-Edupersonprincipalname"), Equals, "myself@testshib.org") + token := Token(r.Context()) + c.Assert(token.Attributes.Get("telephoneNumber"), DeepEquals, "555-5555") + c.Assert(token.Attributes.Get("sn"), Equals, "And I") + c.Assert(token.Attributes.Get("eduPersonEntitlement"), Equals, "urn:mace:dir:entitlement:common-lib-terms") + c.Assert(token.Attributes.Get("eduPersonTargetedID"), Equals, "") + c.Assert(token.Attributes.Get("givenName"), Equals, "Me Myself") + c.Assert(token.Attributes.Get("cn"), Equals, "Me Myself And I") + c.Assert(token.Attributes.Get("uid"), Equals, "myself") + c.Assert(token.Attributes.Get("eduPersonPrincipalName"), Equals, "myself@testshib.org") + c.Assert(token.Attributes["eduPersonScopedAffiliation"], DeepEquals, []string{"Member@testshib.org", "Staff@testshib.org"}) + c.Assert(token.Attributes["eduPersonAffiliation"], DeepEquals, []string{"Member", "Staff"}) w.WriteHeader(http.StatusTeapot) })) @@ -241,30 +241,6 @@ func (test *MiddlewareTest) TestRequireAccountCreds(c *C) { c.Assert(resp.Code, Equals, http.StatusTeapot) } -func (test *MiddlewareTest) TestFiltersSpecialHeadersInRequest(c *C) { - handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - panic("not reached") - })) - - { - req, _ := http.NewRequest("GET", "/frob", nil) - req.Header.Set("X-Saml-Uid", "root") // evil - req.Header.Set("Cookie", "ttt="+expectedToken+"; Path=/; Max-Age=7200") - resp := httptest.NewRecorder() - c.Assert(func() { handler.ServeHTTP(resp, req) }, PanicMatches, "X-Saml-\\* headers should not exist when this function is called") - } - - // make sure case folding works - { - req, _ := http.NewRequest("GET", "/frob", nil) - req.Header.Set("x-SAML-uId", "root") // evil - req.Header.Set("Cookie", "ttt="+expectedToken+"; Path=/; Max-Age=7200") - resp := httptest.NewRecorder() - c.Assert(func() { handler.ServeHTTP(resp, req) }, PanicMatches, "X-Saml-\\* headers should not exist when this function is called") - } -} - func (test *MiddlewareTest) TestRequireAccountBadCreds(c *C) { handler := test.Middleware.RequireAccount( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -289,7 +265,6 @@ func (test *MiddlewareTest) TestRequireAccountBadCreds(c *C) { decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) c.Assert(err, IsNil) c.Assert(string(decodedRequest), Equals, "https://15661444.ngrok.io/saml2/metadata") - } func (test *MiddlewareTest) TestRequireAccountExpiredCreds(c *C) { @@ -335,22 +310,6 @@ func (test *MiddlewareTest) TestRequireAccountPanicOnRequestToACS(c *C) { "don't wrap Middleware with RequireAccount") } -func (test *MiddlewareTest) TestRejectRequestWithMagicHeader(c *C) { - handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - panic("not reached") - })) - - req, _ := http.NewRequest("GET", "/frob", nil) - req.Header.Set("Cookie", ""+ - "ttt="+expectedToken+"; "+ - "Path=/; Max-Age=7200") - req.Header.Set("X-Saml-Uid", "root") // ... evil - resp := httptest.NewRecorder() - c.Assert(func() { handler.ServeHTTP(resp, req) }, Panics, - "X-Saml-* headers should not exist when this function is called") -} - func (test *MiddlewareTest) TestRequireAttribute(c *C) { handler := test.Middleware.RequireAccount( RequireAttribute("eduPersonAffiliation", "Staff")( diff --git a/samlsp/token.go b/samlsp/token.go new file mode 100644 index 00000000..b4f239e9 --- /dev/null +++ b/samlsp/token.go @@ -0,0 +1,47 @@ +package samlsp + +import ( + "context" + + jwt "github.com/dgrijalva/jwt-go" +) + +// AuthorizationToken represents the data stored in the authorization cookie. +type AuthorizationToken struct { + jwt.StandardClaims + Attributes Attributes `json:"attr"` +} + +// Attributes is a map of attributes provided in the SAML assertion +type Attributes map[string][]string + +// Get returns the first attribute named `key` or an empty string if +// no such attributes is present. +func (a Attributes) Get(key string) string { + if a == nil { + return "" + } + v := a[key] + if len(v) == 0 { + return "" + } + return v[0] +} + +type indexType int + +const tokenIndex indexType = iota + +// Token returns the token associated with ctx, or nil if no token are associated +func Token(ctx context.Context) *AuthorizationToken { + v := ctx.Value(tokenIndex) + if v == nil { + return nil + } + return v.(*AuthorizationToken) +} + +// WithToken returns a new context with token associated +func WithToken(ctx context.Context, token *AuthorizationToken) context.Context { + return context.WithValue(ctx, tokenIndex, token) +}