Skip to content

Commit

Permalink
samlsp: remove X-Saml headers in favor of attaching Claims to request…
Browse files Browse the repository at this point in the history
… context (#131)
  • Loading branch information
crewjam authored Jan 8, 2018
1 parent e9d713d commit c9c2cbc
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 106 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ import (
)

func hello(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, %s!", r.Header.Get("X-Saml-Cn"))
fmt.Fprintf(w, "Hello, %s!", samlsp.Token(r.Context()).Attributes.Get("cn"))
}

func main() {
Expand Down
2 changes: 1 addition & 1 deletion example/trivial/trivial.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

func hello(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, %s!", r.Header.Get("X-Saml-Cn"))
fmt.Fprintf(w, "Hello, %s!", samlsp.Token(r.Context()).Attributes.Get("cn"))
}

func main() {
Expand Down
3 changes: 2 additions & 1 deletion saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@
// )
//
// func hello(w http.ResponseWriter, r *http.Request) {
// fmt.Fprintf(w, "Hello, %s!", r.Header.Get("X-Saml-Cn"))
// claims := samlsp.Claims(r.Context())
// fmt.Fprintf(w, "Hello, %s!", claims.Attributes["cn"][0])
// }
//
// func main() {
Expand Down
79 changes: 29 additions & 50 deletions samlsp/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
// It implements http.Handler so that it can provide the metadata and ACS endpoints,
// typically /saml/metadata and /saml/acs, respectively.
//
// It also provides middleware, RequireAccount which redirects users to
// It also provides middleware RequireAccount which redirects users to
// the auth process if they do not have session credentials.
//
// When redirecting the user through the SAML auth flow, the middlware assigns
Expand All @@ -37,12 +37,9 @@ import (
// authenticated attributes from the SAML assertion.
//
// When the middlware receives a request with a valid session JWT it extracts
// the SAML attributes and modifies the http.Request object adding headers
// corresponding to the specified attributes. For example, if the attribute
// "cn" were present in the initial assertion with a value of "Alice Smith",
// then a corresponding header "X-Saml-Cn" will be added to the request with
// a value of "Alice Smith". For safety, the middleware strips out any existing
// headers that begin with "X-Saml-".
// the SAML attributes and modifies the http.Request object adding a Context
// object to the request context that contains attributes from the initial
// SAML assertion.
//
// When issuing JSON Web Tokens, a signing key is required. Because the
// SAML service provider already has a private key, we borrow that key
Expand Down Expand Up @@ -105,7 +102,8 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// to start the SAML auth flow.
func (m *Middleware) RequireAccount(handler http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if m.IsAuthorized(r) {
if token := m.GetAuthorizationToken(r); token != nil {
r = r.WithContext(WithToken(r.Context(), token))
handler.ServeHTTP(w, r)
return
}
Expand Down Expand Up @@ -209,11 +207,6 @@ func (m *Middleware) getPossibleRequestIDs(r *http.Request) []string {
return rv
}

type TokenClaims struct {
jwt.StandardClaims
Attributes map[string][]string `json:"attr"`
}

// Authorize is invoked by ServeHTTP when we have a new, valid SAML assertion.
// It sets a cookie that contains a signed JWT containing the assertion attributes.
// It then redirects the user's browser to the original URL contained in RelayState.
Expand Down Expand Up @@ -250,7 +243,7 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion
}

now := saml.TimeNow()
claims := TokenClaims{}
claims := AuthorizationToken{}
claims.Audience = m.ServiceProvider.Metadata().EntityID
claims.IssuedAt = now.Unix()
claims.ExpiresAt = now.Add(m.CookieMaxAge).Unix()
Expand All @@ -272,6 +265,7 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion
}
}
}

signedToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256,
claims).SignedString(secretBlock)
if err != nil {
Expand All @@ -291,64 +285,49 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion
http.Redirect(w, r, redirectURI, http.StatusFound)
}

// IsAuthorized is invoked by RequireAccount to determine if the request
// is already authorized or if the user's browser should be redirected to the
// SAML login flow. If the request is authorized, then the request headers
// starting with X-Saml- for each SAML assertion attribute are set. For example,
// if an attribute "uid" has the value "alice@example.com", then the following
// header would be added to the request:
// IsAuthorized returns true if the request has already been authorized.
//
// X-Saml-Uid: alice@example.com
//
// It is an error for this function to be invoked with a request containing
// any headers starting with X-Saml. This function will panic if you do.
// Note: This function is retained for compatability. Use GetAuthorizationToken in new code
// instead.
func (m *Middleware) IsAuthorized(r *http.Request) bool {
return m.GetAuthorizationToken(r) != nil
}

// GetAuthorizationToken is invoked by RequireAccount to determine if the request
// is already authorized or if the user's browser should be redirected to the
// SAML login flow. If the request is authorized, then the request context is
// ammended with a Context object.
func (m *Middleware) GetAuthorizationToken(r *http.Request) *AuthorizationToken {
cookie, err := r.Cookie(m.CookieName)
if err != nil {
return false
return nil
}

tokenClaims := TokenClaims{}
tokenClaims := AuthorizationToken{}
token, err := jwt.ParseWithClaims(cookie.Value, &tokenClaims, func(t *jwt.Token) (interface{}, error) {
secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key)
return secretBlock, nil
})
if err != nil || !token.Valid {
m.ServiceProvider.Logger.Printf("ERROR: invalid token: %s", err)
return false
return nil
}
if err := tokenClaims.StandardClaims.Valid(); err != nil {
m.ServiceProvider.Logger.Printf("ERROR: invalid token claims: %s", err)
return false
return nil
}
if tokenClaims.Audience != m.ServiceProvider.Metadata().EntityID {
m.ServiceProvider.Logger.Printf("ERROR: invalid audience: %s", err)
return false
}

// It is an error for the request to include any X-SAML* headers,
// because those might be confused with ours. If we encounter any
// such headers, we abort the request, so there is no confustion.
for headerName := range r.Header {
if strings.HasPrefix(headerName, "X-Saml") {
panic("X-Saml-* headers should not exist when this function is called")
}
}

for claimName, claimValues := range tokenClaims.Attributes {
for _, claimValue := range claimValues {
r.Header.Add("X-Saml-"+claimName, claimValue)
}
return nil
}
r.Header.Set("X-Saml-Subject", tokenClaims.Subject)

return true
return &tokenClaims
}

// RequireAttribute returns a middleware function that requires that the
// SAML attribute `name` be set to `value`. This can be used to require
// that a remote user be a member of a group. It relies on the X-Saml-* headers
// that RequireAccount adds to the request.
// that a remote user be a member of a group. It relies on the Claims assigned
// to to the context in RequireAccount.
//
// For example:
//
Expand All @@ -358,8 +337,8 @@ func (m *Middleware) IsAuthorized(r *http.Request) bool {
func RequireAttribute(name, value string) func(http.Handler) http.Handler {
return func(handler http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if values, ok := r.Header[http.CanonicalHeaderKey(fmt.Sprintf("X-Saml-%s", name))]; ok {
for _, actualValue := range values {
if claims := Token(r.Context()); claims != nil {
for _, actualValue := range claims.Attributes[name] {
if actualValue == value {
handler.ServeHTTP(w, r)
return
Expand Down
65 changes: 12 additions & 53 deletions samlsp/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/xml"
"io/ioutil"
Expand All @@ -17,8 +18,6 @@ import (
dsig "github.com/russellhaering/goxmldsig"
. "gopkg.in/check.v1"

"crypto/x509"

"github.com/crewjam/saml"
"github.com/crewjam/saml/logger"
"github.com/crewjam/saml/testsaml"
Expand Down Expand Up @@ -218,16 +217,17 @@ func (test *MiddlewareTest) TestRequireAccountNoCredsPostBinding(c *C) {
func (test *MiddlewareTest) TestRequireAccountCreds(c *C) {
handler := test.Middleware.RequireAccount(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c.Assert(r.Header.Get("X-Saml-Telephonenumber"), Equals, "555-5555")
c.Assert(r.Header["X-Saml-Edupersonscopedaffiliation"], DeepEquals, []string{"Member@testshib.org", "Staff@testshib.org"})
c.Assert(r.Header.Get("X-Saml-Sn"), Equals, "And I")
c.Assert(r.Header.Get("X-Saml-Edupersonentitlement"), Equals, "urn:mace:dir:entitlement:common-lib-terms")
c.Assert(r.Header.Get("X-Saml-Edupersontargetedid"), Equals, "")
c.Assert(r.Header.Get("X-Saml-Givenname"), Equals, "Me Myself")
c.Assert(r.Header.Get("X-Saml-Cn"), Equals, "Me Myself And I")
c.Assert(r.Header["X-Saml-Edupersonaffiliation"], DeepEquals, []string{"Member", "Staff"})
c.Assert(r.Header.Get("X-Saml-Uid"), Equals, "myself")
c.Assert(r.Header.Get("X-Saml-Edupersonprincipalname"), Equals, "myself@testshib.org")
token := Token(r.Context())
c.Assert(token.Attributes.Get("telephoneNumber"), DeepEquals, "555-5555")
c.Assert(token.Attributes.Get("sn"), Equals, "And I")
c.Assert(token.Attributes.Get("eduPersonEntitlement"), Equals, "urn:mace:dir:entitlement:common-lib-terms")
c.Assert(token.Attributes.Get("eduPersonTargetedID"), Equals, "")
c.Assert(token.Attributes.Get("givenName"), Equals, "Me Myself")
c.Assert(token.Attributes.Get("cn"), Equals, "Me Myself And I")
c.Assert(token.Attributes.Get("uid"), Equals, "myself")
c.Assert(token.Attributes.Get("eduPersonPrincipalName"), Equals, "myself@testshib.org")
c.Assert(token.Attributes["eduPersonScopedAffiliation"], DeepEquals, []string{"Member@testshib.org", "Staff@testshib.org"})
c.Assert(token.Attributes["eduPersonAffiliation"], DeepEquals, []string{"Member", "Staff"})
w.WriteHeader(http.StatusTeapot)
}))

Expand All @@ -241,30 +241,6 @@ func (test *MiddlewareTest) TestRequireAccountCreds(c *C) {
c.Assert(resp.Code, Equals, http.StatusTeapot)
}

func (test *MiddlewareTest) TestFiltersSpecialHeadersInRequest(c *C) {
handler := test.Middleware.RequireAccount(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("not reached")
}))

{
req, _ := http.NewRequest("GET", "/frob", nil)
req.Header.Set("X-Saml-Uid", "root") // evil
req.Header.Set("Cookie", "ttt="+expectedToken+"; Path=/; Max-Age=7200")
resp := httptest.NewRecorder()
c.Assert(func() { handler.ServeHTTP(resp, req) }, PanicMatches, "X-Saml-\\* headers should not exist when this function is called")
}

// make sure case folding works
{
req, _ := http.NewRequest("GET", "/frob", nil)
req.Header.Set("x-SAML-uId", "root") // evil
req.Header.Set("Cookie", "ttt="+expectedToken+"; Path=/; Max-Age=7200")
resp := httptest.NewRecorder()
c.Assert(func() { handler.ServeHTTP(resp, req) }, PanicMatches, "X-Saml-\\* headers should not exist when this function is called")
}
}

func (test *MiddlewareTest) TestRequireAccountBadCreds(c *C) {
handler := test.Middleware.RequireAccount(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -289,7 +265,6 @@ func (test *MiddlewareTest) TestRequireAccountBadCreds(c *C) {
decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL)
c.Assert(err, IsNil)
c.Assert(string(decodedRequest), Equals, "<samlp:AuthnRequest xmlns:saml=\"urn:oasis:names:tc:SAML:2.0:assertion\" xmlns:samlp=\"urn:oasis:names:tc:SAML:2.0:protocol\" ID=\"id-00020406080a0c0e10121416181a1c1e20222426\" Version=\"2.0\" IssueInstant=\"2015-12-01T01:57:09.123Z\" Destination=\"https://idp.testshib.org/idp/profile/SAML2/Redirect/SSO\" AssertionConsumerServiceURL=\"https://15661444.ngrok.io/saml2/acs\" ProtocolBinding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"><saml:Issuer Format=\"urn:oasis:names:tc:SAML:2.0:nameid-format:entity\">https://15661444.ngrok.io/saml2/metadata</saml:Issuer><samlp:NameIDPolicy Format=\"urn:oasis:names:tc:SAML:2.0:nameid-format:transient\" AllowCreate=\"true\"/></samlp:AuthnRequest>")

}

func (test *MiddlewareTest) TestRequireAccountExpiredCreds(c *C) {
Expand Down Expand Up @@ -335,22 +310,6 @@ func (test *MiddlewareTest) TestRequireAccountPanicOnRequestToACS(c *C) {
"don't wrap Middleware with RequireAccount")
}

func (test *MiddlewareTest) TestRejectRequestWithMagicHeader(c *C) {
handler := test.Middleware.RequireAccount(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("not reached")
}))

req, _ := http.NewRequest("GET", "/frob", nil)
req.Header.Set("Cookie", ""+
"ttt="+expectedToken+"; "+
"Path=/; Max-Age=7200")
req.Header.Set("X-Saml-Uid", "root") // ... evil
resp := httptest.NewRecorder()
c.Assert(func() { handler.ServeHTTP(resp, req) }, Panics,
"X-Saml-* headers should not exist when this function is called")
}

func (test *MiddlewareTest) TestRequireAttribute(c *C) {
handler := test.Middleware.RequireAccount(
RequireAttribute("eduPersonAffiliation", "Staff")(
Expand Down
47 changes: 47 additions & 0 deletions samlsp/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package samlsp

import (
"context"

jwt "github.com/dgrijalva/jwt-go"
)

// AuthorizationToken represents the data stored in the authorization cookie.
type AuthorizationToken struct {
jwt.StandardClaims
Attributes Attributes `json:"attr"`
}

// Attributes is a map of attributes provided in the SAML assertion
type Attributes map[string][]string

// Get returns the first attribute named `key` or an empty string if
// no such attributes is present.
func (a Attributes) Get(key string) string {
if a == nil {
return ""
}
v := a[key]
if len(v) == 0 {
return ""
}
return v[0]
}

type indexType int

const tokenIndex indexType = iota

// Token returns the token associated with ctx, or nil if no token are associated
func Token(ctx context.Context) *AuthorizationToken {
v := ctx.Value(tokenIndex)
if v == nil {
return nil
}
return v.(*AuthorizationToken)
}

// WithToken returns a new context with token associated
func WithToken(ctx context.Context, token *AuthorizationToken) context.Context {
return context.WithValue(ctx, tokenIndex, token)
}

0 comments on commit c9c2cbc

Please sign in to comment.