Skip to content

Commit

Permalink
feat: use int64 type for claims with timestamps (#600)
Browse files Browse the repository at this point in the history
Co-authored-by: Nestor <nesterran@gmail.com>
  • Loading branch information
nestorvw and narg95 authored May 28, 2021
1 parent 5def9a4 commit c370994
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 62 deletions.
2 changes: 2 additions & 0 deletions client_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
switch exp := claims["exp"].(type) {
case float64:
expiry = int64(exp)
case int64:
expiry = exp
case json.Number:
expiry, err = exp.Int64()
default:
Expand Down
6 changes: 3 additions & 3 deletions token/jwt/claims_id_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,19 @@ func (c *IDTokenClaims) ToMap() map[string]interface{} {
}

if !c.IssuedAt.IsZero() {
ret["iat"] = float64(c.IssuedAt.Unix()) // jwt-go does not support int64 as datatype
ret["iat"] = c.IssuedAt.Unix()
} else {
delete(ret, "iat")
}

if !c.ExpiresAt.IsZero() {
ret["exp"] = float64(c.ExpiresAt.Unix()) // jwt-go does not support int64 as datatype
ret["exp"] = c.ExpiresAt.Unix()
} else {
delete(ret, "exp")
}

if !c.RequestedAt.IsZero() {
ret["rat"] = float64(c.RequestedAt.Unix()) // jwt-go does not support int64 as datatype
ret["rat"] = c.RequestedAt.Unix()
} else {
delete(ret, "rat")
}
Expand Down
12 changes: 6 additions & 6 deletions token/jwt/claims_id_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ func TestIDTokenClaimsToMap(t *testing.T) {
assert.Equal(t, map[string]interface{}{
"jti": idTokenClaims.JTI,
"sub": idTokenClaims.Subject,
"iat": float64(idTokenClaims.IssuedAt.Unix()),
"rat": float64(idTokenClaims.RequestedAt.Unix()),
"iat": idTokenClaims.IssuedAt.Unix(),
"rat": idTokenClaims.RequestedAt.Unix(),
"iss": idTokenClaims.Issuer,
"aud": idTokenClaims.Audience,
"exp": float64(idTokenClaims.ExpiresAt.Unix()),
"exp": idTokenClaims.ExpiresAt.Unix(),
"foo": idTokenClaims.Extra["foo"],
"baz": idTokenClaims.Extra["baz"],
"at_hash": idTokenClaims.AccessTokenHash,
Expand All @@ -79,11 +79,11 @@ func TestIDTokenClaimsToMap(t *testing.T) {
assert.Equal(t, map[string]interface{}{
"jti": idTokenClaims.JTI,
"sub": idTokenClaims.Subject,
"iat": float64(idTokenClaims.IssuedAt.Unix()),
"rat": float64(idTokenClaims.RequestedAt.Unix()),
"iat": idTokenClaims.IssuedAt.Unix(),
"rat": idTokenClaims.RequestedAt.Unix(),
"iss": idTokenClaims.Issuer,
"aud": idTokenClaims.Audience,
"exp": float64(idTokenClaims.ExpiresAt.Unix()),
"exp": idTokenClaims.ExpiresAt.Unix(),
"foo": idTokenClaims.Extra["foo"],
"baz": idTokenClaims.Extra["baz"],
"at_hash": idTokenClaims.AccessTokenHash,
Expand Down
46 changes: 21 additions & 25 deletions token/jwt/claims_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,19 @@ func (c *JWTClaims) ToMap() map[string]interface{} {
}

if !c.IssuedAt.IsZero() {
ret["iat"] = float64(c.IssuedAt.Unix()) // jwt-go does not support int64 as datatype
ret["iat"] = c.IssuedAt.Unix()
} else {
delete(ret, "iat")
}

if !c.NotBefore.IsZero() {
ret["nbf"] = float64(c.NotBefore.Unix()) // jwt-go does not support int64 as datatype
ret["nbf"] = c.NotBefore.Unix()
} else {
delete(ret, "nbf")
}

if !c.ExpiresAt.IsZero() {
ret["exp"] = float64(c.ExpiresAt.Unix()) // jwt-go does not support int64 as datatype
ret["exp"] = c.ExpiresAt.Unix()
} else {
delete(ret, "exp")
}
Expand Down Expand Up @@ -183,38 +183,23 @@ func (c *JWTClaims) FromMap(m map[string]interface{}) {
c.Audience = s
}
case "iat":
switch v.(type) {
case float64:
c.IssuedAt = time.Unix(int64(v.(float64)), 0).UTC()
case int64:
c.IssuedAt = time.Unix(v.(int64), 0).UTC()
}
c.IssuedAt = toTime(v, c.IssuedAt)
case "nbf":
switch v.(type) {
case float64:
c.NotBefore = time.Unix(int64(v.(float64)), 0).UTC()
case int64:
c.NotBefore = time.Unix(v.(int64), 0).UTC()
}
c.NotBefore = toTime(v, c.NotBefore)
case "exp":
switch v.(type) {
case float64:
c.ExpiresAt = time.Unix(int64(v.(float64)), 0).UTC()
case int64:
c.ExpiresAt = time.Unix(v.(int64), 0).UTC()
}
c.ExpiresAt = toTime(v, c.ExpiresAt)
case "scp":
switch v.(type) {
switch s := v.(type) {
case []string:
c.Scope = v.([]string)
c.Scope = s
if c.ScopeField == JWTScopeFieldString {
c.ScopeField = JWTScopeFieldBoth
} else if c.ScopeField == JWTScopeFieldUnset {
c.ScopeField = JWTScopeFieldList
}
case []interface{}:
c.Scope = make([]string, len(v.([]interface{})))
for i, vi := range v.([]interface{}) {
c.Scope = make([]string, len(s))
for i, vi := range s {
if s, ok := vi.(string); ok {
c.Scope[i] = s
}
Expand All @@ -240,6 +225,17 @@ func (c *JWTClaims) FromMap(m map[string]interface{}) {
}
}

func toTime(v interface{}, def time.Time) (t time.Time) {
t = def
switch a := v.(type) {
case float64:
t = time.Unix(int64(a), 0).UTC()
case int64:
t = time.Unix(a, 0).UTC()
}
return
}

// Add will add a key-value pair to the extra field
func (c *JWTClaims) Add(key string, value interface{}) {
if c.Extra == nil {
Expand Down
8 changes: 4 additions & 4 deletions token/jwt/claims_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ var jwtClaims = &JWTClaims{

var jwtClaimsMap = map[string]interface{}{
"sub": jwtClaims.Subject,
"iat": float64(jwtClaims.IssuedAt.Unix()),
"iat": jwtClaims.IssuedAt.Unix(),
"iss": jwtClaims.Issuer,
"nbf": float64(jwtClaims.NotBefore.Unix()),
"nbf": jwtClaims.NotBefore.Unix(),
"aud": jwtClaims.Audience,
"exp": float64(jwtClaims.ExpiresAt.Unix()),
"exp": jwtClaims.ExpiresAt.Unix(),
"jti": jwtClaims.JTI,
"scp": []string{"email", "offline"},
"foo": jwtClaims.Extra["foo"],
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestScopeFieldString(t *testing.T) {

func TestScopeFieldBoth(t *testing.T) {
jwtClaimsWithBoth := jwtClaims.WithScopeField(JWTScopeFieldBoth)
// Making a copy of jwtClaimsMap.
// Making a copy of jwtClaimsMap
jwtClaimsMapWithBoth := jwtClaims.ToMap()
jwtClaimsMapWithBoth["scope"] = "email offline"
assert.Equal(t, jwtClaimsMapWithBoth, map[string]interface{}(jwtClaimsWithBoth.ToMapClaims()))
Expand Down
74 changes: 58 additions & 16 deletions token/jwt/map_claims.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package jwt

import (
"bytes"
"crypto/subtle"
"encoding/json"
"errors"
"time"
// "fmt"

"github.com/ory/x/errorsx"
jjson "gopkg.in/square/go-jose.v2/json"
)

var TimeFunc = time.Now
Expand Down Expand Up @@ -43,11 +46,7 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool {
// Compares the exp claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
switch exp := m["exp"].(type) {
case float64:
return verifyExp(int64(exp), cmp, req)
case json.Number:
v, _ := exp.Int64()
if v, ok := m.toInt64("exp"); ok {
return verifyExp(v, cmp, req)
}
return !req
Expand All @@ -56,11 +55,7 @@ func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
// Compares the iat claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool {
switch iat := m["iat"].(type) {
case float64:
return verifyIat(int64(iat), cmp, req)
case json.Number:
v, _ := iat.Int64()
if v, ok := m.toInt64("iat"); ok {
return verifyIat(v, cmp, req)
}
return !req
Expand All @@ -76,16 +71,34 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool {
// Compares the nbf claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool {
switch nbf := m["nbf"].(type) {
case float64:
return verifyNbf(int64(nbf), cmp, req)
case json.Number:
v, _ := nbf.Int64()
if v, ok := m.toInt64("nbf"); ok {
return verifyNbf(v, cmp, req)
}

return !req
}

func (m MapClaims) toInt64(claim string) (int64, bool) {
switch t := m[claim].(type) {
case float64:
return int64(t), true
case int64:
return t, true
case json.Number:
v, err := t.Int64()
if err == nil {
return v, true
}
vf, err := t.Float64()
if err != nil {
return 0, false
}

return int64(vf), true
}
return 0, false
}

// Validates time based claims "exp, iat, nbf".
// There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
Expand Down Expand Up @@ -116,6 +129,35 @@ func (m MapClaims) Valid() error {
return vErr
}

func (m MapClaims) UnmarshalJSON(b []byte) error {
// A custom unmarshal is required in order to convert float64 integer values to int64.
// It does it on the first level of a map for relevant claims like "iat", "exp" and "nbf",
// but also applicable to any first level claim.
//
// This custom Unmarshal can be removed once this PR gets merged
// https://github.com/square/go-jose/pull/352
d := jjson.NewDecoder(bytes.NewReader(b))
mp := map[string]interface{}(m)
if err := d.Decode(&mp); err != nil {
return errorsx.WithStack(err)
}

for k, v := range mp {
switch n := v.(type) {
case float64:
intv := int64(n)
// this checks that no precision gets lost
// and that the number fits into a int64
if n != float64(intv) {
continue
}
m[k] = intv
}
}

return nil
}

func verifyAud(aud []string, cmp string, required bool) bool {
if len(aud) == 0 {
return !required
Expand Down
16 changes: 8 additions & 8 deletions token/jwt/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ func TestParser_Parse(t *testing.T) {
given: given{
name: "basic expired",
generate: &generate{
claims: MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
claims: MapClaims{"foo": "bar", "exp": time.Now().Unix() - 100},
},
},
expected: expected{
keyFunc: defaultKeyFunc,
claims: MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
claims: MapClaims{"foo": "bar", "exp": time.Now().Unix() - 100},
valid: false,
errors: ValidationErrorExpired,
},
Expand All @@ -106,12 +106,12 @@ func TestParser_Parse(t *testing.T) {
given: given{
name: "basic nbf",
generate: &generate{
claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100},
},
},
expected: expected{
keyFunc: defaultKeyFunc,
claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100},
valid: false,
errors: ValidationErrorNotValidYet,
},
Expand All @@ -120,12 +120,12 @@ func TestParser_Parse(t *testing.T) {
given: given{
name: "expired and nbf",
generate: &generate{
claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100, "exp": time.Now().Unix() - 100},
},
},
expected: expected{
keyFunc: defaultKeyFunc,
claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
claims: MapClaims{"foo": "bar", "nbf": time.Now().Unix() + 100, "exp": time.Now().Unix() - 100},
valid: false,
errors: ValidationErrorNotValidYet | ValidationErrorExpired,
},
Expand Down Expand Up @@ -259,12 +259,12 @@ func TestParser_Parse(t *testing.T) {
given: given{
name: "used before issued",
generate: &generate{
claims: MapClaims{"foo": "bar", "iat": float64(time.Now().Unix() + 500)},
claims: MapClaims{"foo": "bar", "iat": time.Now().Unix() + 500},
},
},
expected: expected{
keyFunc: defaultKeyFunc,
claims: MapClaims{"foo": "bar", "iat": float64(time.Now().Unix() + 500)},
claims: MapClaims{"foo": "bar", "iat": time.Now().Unix() + 500},
valid: false,
errors: ValidationErrorIssuedAt,
},
Expand Down

0 comments on commit c370994

Please sign in to comment.