Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use int64 type for claims with timestamps #600

Merged
merged 2 commits into from
May 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions client_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
switch exp := claims["exp"].(type) {
case float64:
expiry = int64(exp)
case int64:
expiry = exp
case json.Number:
expiry, err = exp.Int64()
default:
Expand Down
6 changes: 3 additions & 3 deletions token/jwt/claims_id_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,19 @@ func (c *IDTokenClaims) ToMap() map[string]interface{} {
}

if !c.IssuedAt.IsZero() {
ret["iat"] = float64(c.IssuedAt.Unix()) // jwt-go does not support int64 as datatype
ret["iat"] = c.IssuedAt.Unix()
} else {
delete(ret, "iat")
}

if !c.ExpiresAt.IsZero() {
ret["exp"] = float64(c.ExpiresAt.Unix()) // jwt-go does not support int64 as datatype
ret["exp"] = c.ExpiresAt.Unix()
} else {
delete(ret, "exp")
}

if !c.RequestedAt.IsZero() {
ret["rat"] = float64(c.RequestedAt.Unix()) // jwt-go does not support int64 as datatype
ret["rat"] = c.RequestedAt.Unix()
} else {
delete(ret, "rat")
}
Expand Down
12 changes: 6 additions & 6 deletions token/jwt/claims_id_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ func TestIDTokenClaimsToMap(t *testing.T) {
assert.Equal(t, map[string]interface{}{
"jti": idTokenClaims.JTI,
"sub": idTokenClaims.Subject,
"iat": float64(idTokenClaims.IssuedAt.Unix()),
"rat": float64(idTokenClaims.RequestedAt.Unix()),
"iat": idTokenClaims.IssuedAt.Unix(),
"rat": idTokenClaims.RequestedAt.Unix(),
"iss": idTokenClaims.Issuer,
"aud": idTokenClaims.Audience,
"exp": float64(idTokenClaims.ExpiresAt.Unix()),
"exp": idTokenClaims.ExpiresAt.Unix(),
"foo": idTokenClaims.Extra["foo"],
"baz": idTokenClaims.Extra["baz"],
"at_hash": idTokenClaims.AccessTokenHash,
Expand All @@ -79,11 +79,11 @@ func TestIDTokenClaimsToMap(t *testing.T) {
assert.Equal(t, map[string]interface{}{
"jti": idTokenClaims.JTI,
"sub": idTokenClaims.Subject,
"iat": float64(idTokenClaims.IssuedAt.Unix()),
"rat": float64(idTokenClaims.RequestedAt.Unix()),
"iat": idTokenClaims.IssuedAt.Unix(),
"rat": idTokenClaims.RequestedAt.Unix(),
"iss": idTokenClaims.Issuer,
"aud": idTokenClaims.Audience,
"exp": float64(idTokenClaims.ExpiresAt.Unix()),
"exp": idTokenClaims.ExpiresAt.Unix(),
"foo": idTokenClaims.Extra["foo"],
"baz": idTokenClaims.Extra["baz"],
"at_hash": idTokenClaims.AccessTokenHash,
Expand Down
46 changes: 21 additions & 25 deletions token/jwt/claims_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,19 @@ func (c *JWTClaims) ToMap() map[string]interface{} {
}

if !c.IssuedAt.IsZero() {
ret["iat"] = float64(c.IssuedAt.Unix()) // jwt-go does not support int64 as datatype
ret["iat"] = c.IssuedAt.Unix()
} else {
delete(ret, "iat")
}

if !c.NotBefore.IsZero() {
ret["nbf"] = float64(c.NotBefore.Unix()) // jwt-go does not support int64 as datatype
ret["nbf"] = c.NotBefore.Unix()
} else {
delete(ret, "nbf")
}

if !c.ExpiresAt.IsZero() {
ret["exp"] = float64(c.ExpiresAt.Unix()) // jwt-go does not support int64 as datatype
ret["exp"] = c.ExpiresAt.Unix()
} else {
delete(ret, "exp")
}
Expand Down Expand Up @@ -183,38 +183,23 @@ func (c *JWTClaims) FromMap(m map[string]interface{}) {
c.Audience = s
}
case "iat":
switch v.(type) {
case float64:
c.IssuedAt = time.Unix(int64(v.(float64)), 0).UTC()
case int64:
c.IssuedAt = time.Unix(v.(int64), 0).UTC()
}
c.IssuedAt = toTime(v, c.IssuedAt)
case "nbf":
switch v.(type) {
case float64:
c.NotBefore = time.Unix(int64(v.(float64)), 0).UTC()
case int64:
c.NotBefore = time.Unix(v.(int64), 0).UTC()
}
c.NotBefore = toTime(v, c.NotBefore)
case "exp":
switch v.(type) {
case float64:
c.ExpiresAt = time.Unix(int64(v.(float64)), 0).UTC()
case int64:
c.ExpiresAt = time.Unix(v.(int64), 0).UTC()
}
c.ExpiresAt = toTime(v, c.ExpiresAt)
case "scp":
switch v.(type) {
switch s := v.(type) {
case []string:
c.Scope = v.([]string)
c.Scope = s
if c.ScopeField == JWTScopeFieldString {
c.ScopeField = JWTScopeFieldBoth
} else if c.ScopeField == JWTScopeFieldUnset {
c.ScopeField = JWTScopeFieldList
}
case []interface{}:
c.Scope = make([]string, len(v.([]interface{})))
for i, vi := range v.([]interface{}) {
c.Scope = make([]string, len(s))
for i, vi := range s {
aeneasr marked this conversation as resolved.
Show resolved Hide resolved
if s, ok := vi.(string); ok {
c.Scope[i] = s
}
Expand All @@ -240,6 +225,17 @@ func (c *JWTClaims) FromMap(m map[string]interface{}) {
}
}

func toTime(v interface{}, def time.Time) (t time.Time) {
t = def
switch a := v.(type) {
case float64:
t = time.Unix(int64(a), 0).UTC()
aeneasr marked this conversation as resolved.
Show resolved Hide resolved
case int64:
t = time.Unix(a, 0).UTC()
}
return
}

// Add will add a key-value pair to the extra field
func (c *JWTClaims) Add(key string, value interface{}) {
if c.Extra == nil {
Expand Down
8 changes: 4 additions & 4 deletions token/jwt/claims_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ var jwtClaims = &JWTClaims{

var jwtClaimsMap = map[string]interface{}{
"sub": jwtClaims.Subject,
"iat": float64(jwtClaims.IssuedAt.Unix()),
"iat": jwtClaims.IssuedAt.Unix(),
"iss": jwtClaims.Issuer,
"nbf": float64(jwtClaims.NotBefore.Unix()),
"nbf": jwtClaims.NotBefore.Unix(),
"aud": jwtClaims.Audience,
"exp": float64(jwtClaims.ExpiresAt.Unix()),
"exp": jwtClaims.ExpiresAt.Unix(),
"jti": jwtClaims.JTI,
"scp": []string{"email", "offline"},
"foo": jwtClaims.Extra["foo"],
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestScopeFieldString(t *testing.T) {

func TestScopeFieldBoth(t *testing.T) {
jwtClaimsWithBoth := jwtClaims.WithScopeField(JWTScopeFieldBoth)
// Making a copy of jwtClaimsMap.
// Making a copy of jwtClaimsMap
jwtClaimsMapWithBoth := jwtClaims.ToMap()
jwtClaimsMapWithBoth["scope"] = "email offline"
assert.Equal(t, jwtClaimsMapWithBoth, map[string]interface{}(jwtClaimsWithBoth.ToMapClaims()))
Expand Down
74 changes: 58 additions & 16 deletions token/jwt/map_claims.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package jwt

import (
"bytes"
"crypto/subtle"
"encoding/json"
"errors"
"time"
// "fmt"

"github.com/ory/x/errorsx"
jjson "gopkg.in/square/go-jose.v2/json"
)

var TimeFunc = time.Now
Expand Down Expand Up @@ -43,11 +46,7 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool {
// Compares the exp claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
switch exp := m["exp"].(type) {
case float64:
return verifyExp(int64(exp), cmp, req)
case json.Number:
v, _ := exp.Int64()
if v, ok := m.toInt64("exp"); ok {
return verifyExp(v, cmp, req)
}
return !req
Expand All @@ -56,11 +55,7 @@ func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
// Compares the iat claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool {
switch iat := m["iat"].(type) {
case float64:
return verifyIat(int64(iat), cmp, req)
case json.Number:
v, _ := iat.Int64()
if v, ok := m.toInt64("iat"); ok {
return verifyIat(v, cmp, req)
}
return !req
Expand All @@ -76,16 +71,34 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool {
// Compares the nbf claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool {
switch nbf := m["nbf"].(type) {
case float64:
return verifyNbf(int64(nbf), cmp, req)
case json.Number:
v, _ := nbf.Int64()
if v, ok := m.toInt64("nbf"); ok {
return verifyNbf(v, cmp, req)
}

return !req
}

func (m MapClaims) toInt64(claim string) (int64, bool) {
switch t := m[claim].(type) {
case float64:
return int64(t), true
case int64:
return t, true
case json.Number:
v, err := t.Int64()
aeneasr marked this conversation as resolved.
Show resolved Hide resolved
if err == nil {
return v, true
}
vf, err := t.Float64()
if err != nil {
return 0, false
}

return int64(vf), true
}
return 0, false
}

// Validates time based claims "exp, iat, nbf".
// There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
Expand Down Expand Up @@ -116,6 +129,35 @@ func (m MapClaims) Valid() error {
return vErr
}

func (m MapClaims) UnmarshalJSON(b []byte) error {
// A custom unmarshal is required in order to convert float64 integer values to int64.
// It does it on the first level of a map for relevant claims like "iat", "exp" and "nbf",
// but also applicable to any first level claim.
//
// This custom Unmarshal can be removed once this PR gets merged
// https://github.com/square/go-jose/pull/352
d := jjson.NewDecoder(bytes.NewReader(b))
mp := map[string]interface{}(m)
if err := d.Decode(&mp); err != nil {
return errorsx.WithStack(err)
}

for k, v := range mp {
switch n := v.(type) {
case float64:
intv := int64(n)
// this checks that no precision gets lost
// and that the number fits into a int64
if n != float64(intv) {
continue
}
m[k] = intv
}
}

return nil
}

func verifyAud(aud []string, cmp string, required bool) bool {
if len(aud) == 0 {
return !required
Expand Down
16 changes: 8 additions & 8 deletions token/jwt/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ func TestParser_Parse(t *testing.T) {
given: given{
name: "basic expired",
generate: &generate{
claims: MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
claims: MapClaims{"foo": "bar", "exp": time.Now().Unix() - 100},
},
},
expected: expected{
keyFunc: defaultKeyFunc,
claims: MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
claims: MapClaims{"foo": "bar", "exp": time.Now().Unix() - 100},
valid: false,
errors: ValidationErrorExpired,
},
Expand All @@ -106,12 +106,12 @@ func TestParser_Parse(t *testing.T) {
given: given{
name: "basic nbf",
generate: &generate{
claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100},
},
},
expected: expected{
keyFunc: defaultKeyFunc,
claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100},
valid: false,
errors: ValidationErrorNotValidYet,
},
Expand All @@ -120,12 +120,12 @@ func TestParser_Parse(t *testing.T) {
given: given{
name: "expired and nbf",
generate: &generate{
claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100, "exp": time.Now().Unix() - 100},
},
},
expected: expected{
keyFunc: defaultKeyFunc,
claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100, "exp": time.Now().Unix() - 100},
valid: false,
errors: ValidationErrorNotValidYet | ValidationErrorExpired,
},
Expand Down Expand Up @@ -259,12 +259,12 @@ func TestParser_Parse(t *testing.T) {
given: given{
name: "used before issued",
generate: &generate{
claims: MapClaims{"foo": "bar", "iat": float64(time.Now().Unix() + 500)},
claims: MapClaims{"foo": "bar", "iat": time.Now().Unix() + 500},
},
},
expected: expected{
keyFunc: defaultKeyFunc,
claims: MapClaims{"foo": "bar", "iat": float64(time.Now().Unix() + 500)},
claims: MapClaims{"foo": "bar", "iat": time.Now().Unix() + 500},
valid: false,
errors: ValidationErrorIssuedAt,
},
Expand Down