diff --git a/client_authentication.go b/client_authentication.go index 0ebf16080..8e611cc1c 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -181,6 +181,8 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt switch exp := claims["exp"].(type) { case float64: expiry = int64(exp) + case int64: + expiry = exp case json.Number: expiry, err = exp.Int64() default: diff --git a/token/jwt/claims_id_token.go b/token/jwt/claims_id_token.go index 29ed83a1b..901a8c647 100644 --- a/token/jwt/claims_id_token.go +++ b/token/jwt/claims_id_token.go @@ -74,19 +74,19 @@ func (c *IDTokenClaims) ToMap() map[string]interface{} { } if !c.IssuedAt.IsZero() { - ret["iat"] = float64(c.IssuedAt.Unix()) // jwt-go does not support int64 as datatype + ret["iat"] = c.IssuedAt.Unix() } else { delete(ret, "iat") } if !c.ExpiresAt.IsZero() { - ret["exp"] = float64(c.ExpiresAt.Unix()) // jwt-go does not support int64 as datatype + ret["exp"] = c.ExpiresAt.Unix() } else { delete(ret, "exp") } if !c.RequestedAt.IsZero() { - ret["rat"] = float64(c.RequestedAt.Unix()) // jwt-go does not support int64 as datatype + ret["rat"] = c.RequestedAt.Unix() } else { delete(ret, "rat") } diff --git a/token/jwt/claims_id_token_test.go b/token/jwt/claims_id_token_test.go index abd09f466..62af493fa 100644 --- a/token/jwt/claims_id_token_test.go +++ b/token/jwt/claims_id_token_test.go @@ -61,11 +61,11 @@ func TestIDTokenClaimsToMap(t *testing.T) { assert.Equal(t, map[string]interface{}{ "jti": idTokenClaims.JTI, "sub": idTokenClaims.Subject, - "iat": float64(idTokenClaims.IssuedAt.Unix()), - "rat": float64(idTokenClaims.RequestedAt.Unix()), + "iat": idTokenClaims.IssuedAt.Unix(), + "rat": idTokenClaims.RequestedAt.Unix(), "iss": idTokenClaims.Issuer, "aud": idTokenClaims.Audience, - "exp": float64(idTokenClaims.ExpiresAt.Unix()), + "exp": idTokenClaims.ExpiresAt.Unix(), "foo": idTokenClaims.Extra["foo"], "baz": idTokenClaims.Extra["baz"], "at_hash": idTokenClaims.AccessTokenHash, @@ -79,11 +79,11 @@ func TestIDTokenClaimsToMap(t *testing.T) { assert.Equal(t, map[string]interface{}{ "jti": idTokenClaims.JTI, "sub": idTokenClaims.Subject, - "iat": float64(idTokenClaims.IssuedAt.Unix()), - "rat": float64(idTokenClaims.RequestedAt.Unix()), + "iat": idTokenClaims.IssuedAt.Unix(), + "rat": idTokenClaims.RequestedAt.Unix(), "iss": idTokenClaims.Issuer, "aud": idTokenClaims.Audience, - "exp": float64(idTokenClaims.ExpiresAt.Unix()), + "exp": idTokenClaims.ExpiresAt.Unix(), "foo": idTokenClaims.Extra["foo"], "baz": idTokenClaims.Extra["baz"], "at_hash": idTokenClaims.AccessTokenHash, diff --git a/token/jwt/claims_jwt.go b/token/jwt/claims_jwt.go index e48f063fb..3540a0b95 100644 --- a/token/jwt/claims_jwt.go +++ b/token/jwt/claims_jwt.go @@ -126,19 +126,19 @@ func (c *JWTClaims) ToMap() map[string]interface{} { } if !c.IssuedAt.IsZero() { - ret["iat"] = float64(c.IssuedAt.Unix()) // jwt-go does not support int64 as datatype + ret["iat"] = c.IssuedAt.Unix() } else { delete(ret, "iat") } if !c.NotBefore.IsZero() { - ret["nbf"] = float64(c.NotBefore.Unix()) // jwt-go does not support int64 as datatype + ret["nbf"] = c.NotBefore.Unix() } else { delete(ret, "nbf") } if !c.ExpiresAt.IsZero() { - ret["exp"] = float64(c.ExpiresAt.Unix()) // jwt-go does not support int64 as datatype + ret["exp"] = c.ExpiresAt.Unix() } else { delete(ret, "exp") } @@ -183,38 +183,23 @@ func (c *JWTClaims) FromMap(m map[string]interface{}) { c.Audience = s } case "iat": - switch v.(type) { - case float64: - c.IssuedAt = time.Unix(int64(v.(float64)), 0).UTC() - case int64: - c.IssuedAt = time.Unix(v.(int64), 0).UTC() - } + c.IssuedAt = toTime(v, c.IssuedAt) case "nbf": - switch v.(type) { - case float64: - c.NotBefore = time.Unix(int64(v.(float64)), 0).UTC() - case int64: - c.NotBefore = time.Unix(v.(int64), 0).UTC() - } + c.NotBefore = toTime(v, c.NotBefore) case "exp": - switch v.(type) { - case float64: - c.ExpiresAt = time.Unix(int64(v.(float64)), 0).UTC() - case int64: - c.ExpiresAt = time.Unix(v.(int64), 0).UTC() - } + c.ExpiresAt = toTime(v, c.ExpiresAt) case "scp": - switch v.(type) { + switch s := v.(type) { case []string: - c.Scope = v.([]string) + c.Scope = s if c.ScopeField == JWTScopeFieldString { c.ScopeField = JWTScopeFieldBoth } else if c.ScopeField == JWTScopeFieldUnset { c.ScopeField = JWTScopeFieldList } case []interface{}: - c.Scope = make([]string, len(v.([]interface{}))) - for i, vi := range v.([]interface{}) { + c.Scope = make([]string, len(s)) + for i, vi := range s { if s, ok := vi.(string); ok { c.Scope[i] = s } @@ -240,6 +225,17 @@ func (c *JWTClaims) FromMap(m map[string]interface{}) { } } +func toTime(v interface{}, def time.Time) (t time.Time) { + t = def + switch a := v.(type) { + case float64: + t = time.Unix(int64(a), 0).UTC() + case int64: + t = time.Unix(a, 0).UTC() + } + return +} + // Add will add a key-value pair to the extra field func (c *JWTClaims) Add(key string, value interface{}) { if c.Extra == nil { diff --git a/token/jwt/claims_jwt_test.go b/token/jwt/claims_jwt_test.go index fd3194165..3756a7859 100644 --- a/token/jwt/claims_jwt_test.go +++ b/token/jwt/claims_jwt_test.go @@ -48,11 +48,11 @@ var jwtClaims = &JWTClaims{ var jwtClaimsMap = map[string]interface{}{ "sub": jwtClaims.Subject, - "iat": float64(jwtClaims.IssuedAt.Unix()), + "iat": jwtClaims.IssuedAt.Unix(), "iss": jwtClaims.Issuer, - "nbf": float64(jwtClaims.NotBefore.Unix()), + "nbf": jwtClaims.NotBefore.Unix(), "aud": jwtClaims.Audience, - "exp": float64(jwtClaims.ExpiresAt.Unix()), + "exp": jwtClaims.ExpiresAt.Unix(), "jti": jwtClaims.JTI, "scp": []string{"email", "offline"}, "foo": jwtClaims.Extra["foo"], @@ -105,7 +105,7 @@ func TestScopeFieldString(t *testing.T) { func TestScopeFieldBoth(t *testing.T) { jwtClaimsWithBoth := jwtClaims.WithScopeField(JWTScopeFieldBoth) - // Making a copy of jwtClaimsMap. + // Making a copy of jwtClaimsMap jwtClaimsMapWithBoth := jwtClaims.ToMap() jwtClaimsMapWithBoth["scope"] = "email offline" assert.Equal(t, jwtClaimsMapWithBoth, map[string]interface{}(jwtClaimsWithBoth.ToMapClaims())) diff --git a/token/jwt/map_claims.go b/token/jwt/map_claims.go index c5b4e988f..498332dec 100644 --- a/token/jwt/map_claims.go +++ b/token/jwt/map_claims.go @@ -1,11 +1,14 @@ package jwt import ( + "bytes" "crypto/subtle" "encoding/json" "errors" "time" - // "fmt" + + "github.com/ory/x/errorsx" + jjson "gopkg.in/square/go-jose.v2/json" ) var TimeFunc = time.Now @@ -43,11 +46,7 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool { // Compares the exp claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { - switch exp := m["exp"].(type) { - case float64: - return verifyExp(int64(exp), cmp, req) - case json.Number: - v, _ := exp.Int64() + if v, ok := m.toInt64("exp"); ok { return verifyExp(v, cmp, req) } return !req @@ -56,11 +55,7 @@ func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { // Compares the iat claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { - switch iat := m["iat"].(type) { - case float64: - return verifyIat(int64(iat), cmp, req) - case json.Number: - v, _ := iat.Int64() + if v, ok := m.toInt64("iat"); ok { return verifyIat(v, cmp, req) } return !req @@ -76,16 +71,34 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { // Compares the nbf claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { - switch nbf := m["nbf"].(type) { - case float64: - return verifyNbf(int64(nbf), cmp, req) - case json.Number: - v, _ := nbf.Int64() + if v, ok := m.toInt64("nbf"); ok { return verifyNbf(v, cmp, req) } + return !req } +func (m MapClaims) toInt64(claim string) (int64, bool) { + switch t := m[claim].(type) { + case float64: + return int64(t), true + case int64: + return t, true + case json.Number: + v, err := t.Int64() + if err == nil { + return v, true + } + vf, err := t.Float64() + if err != nil { + return 0, false + } + + return int64(vf), true + } + return 0, false +} + // Validates time based claims "exp, iat, nbf". // There is no accounting for clock skew. // As well, if any of the above claims are not in the token, it will still @@ -116,6 +129,35 @@ func (m MapClaims) Valid() error { return vErr } +func (m MapClaims) UnmarshalJSON(b []byte) error { + // A custom unmarshal is required in order to convert float64 integer values to int64. + // It does it on the first level of a map for relevant claims like "iat", "exp" and "nbf", + // but also applicable to any first level claim. + // + // This custom Unmarshal can be removed once this PR gets merged + // https://github.com/square/go-jose/pull/352 + d := jjson.NewDecoder(bytes.NewReader(b)) + mp := map[string]interface{}(m) + if err := d.Decode(&mp); err != nil { + return errorsx.WithStack(err) + } + + for k, v := range mp { + switch n := v.(type) { + case float64: + intv := int64(n) + // this checks that no precision gets lost + // and that the number fits into a int64 + if n != float64(intv) { + continue + } + m[k] = intv + } + } + + return nil +} + func verifyAud(aud []string, cmp string, required bool) bool { if len(aud) == 0 { return !required diff --git a/token/jwt/token_test.go b/token/jwt/token_test.go index f050f0af0..a266dfce2 100644 --- a/token/jwt/token_test.go +++ b/token/jwt/token_test.go @@ -92,12 +92,12 @@ func TestParser_Parse(t *testing.T) { given: given{ name: "basic expired", generate: &generate{ - claims: MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, + claims: MapClaims{"foo": "bar", "exp": time.Now().Unix() - 100}, }, }, expected: expected{ keyFunc: defaultKeyFunc, - claims: MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, + claims: MapClaims{"foo": "bar", "exp": time.Now().Unix() - 100}, valid: false, errors: ValidationErrorExpired, }, @@ -106,12 +106,12 @@ func TestParser_Parse(t *testing.T) { given: given{ name: "basic nbf", generate: &generate{ - claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, + claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100}, }, }, expected: expected{ keyFunc: defaultKeyFunc, - claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, + claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100}, valid: false, errors: ValidationErrorNotValidYet, }, @@ -120,12 +120,12 @@ func TestParser_Parse(t *testing.T) { given: given{ name: "expired and nbf", generate: &generate{ - claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, + claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100, "exp": time.Now().Unix() - 100}, }, }, expected: expected{ keyFunc: defaultKeyFunc, - claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, + claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100, "exp": time.Now().Unix() - 100}, valid: false, errors: ValidationErrorNotValidYet | ValidationErrorExpired, }, @@ -259,12 +259,12 @@ func TestParser_Parse(t *testing.T) { given: given{ name: "used before issued", generate: &generate{ - claims: MapClaims{"foo": "bar", "iat": float64(time.Now().Unix() + 500)}, + claims: MapClaims{"foo": "bar", "iat": time.Now().Unix() + 500}, }, }, expected: expected{ keyFunc: defaultKeyFunc, - claims: MapClaims{"foo": "bar", "iat": float64(time.Now().Unix() + 500)}, + claims: MapClaims{"foo": "bar", "iat": time.Now().Unix() + 500}, valid: false, errors: ValidationErrorIssuedAt, },