diff --git a/claims.go b/claims.go index f0228f02..66653d6e 100644 --- a/claims.go +++ b/claims.go @@ -16,13 +16,13 @@ type Claims interface { // https://tools.ietf.org/html/rfc7519#section-4.1 // See examples for how to use this with your own claim types type StandardClaims struct { - Audience string `json:"aud,omitempty"` - ExpiresAt int64 `json:"exp,omitempty"` - Id string `json:"jti,omitempty"` - IssuedAt int64 `json:"iat,omitempty"` - Issuer string `json:"iss,omitempty"` - NotBefore int64 `json:"nbf,omitempty"` - Subject string `json:"sub,omitempty"` + Audience interface{} `json:"aud,omitempty"` + ExpiresAt int64 `json:"exp,omitempty"` + Id string `json:"jti,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + Issuer string `json:"iss,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + Subject string `json:"sub,omitempty"` } // Validates time based claims "exp, iat, nbf". @@ -58,10 +58,29 @@ func (c StandardClaims) Valid() error { return vErr } -// Compares the aud claim against cmp. +// ExtractAudience extracts an array of audience values from the aud field. +func ExtractAudience(c *StandardClaims) []string { + switch c.Audience.(type) { + case nil: + return []string{} + case []interface{}: + auds := make([]string, len(c.Audience.([]interface{}))) + for i, value := range c.Audience.([]interface{}) { + auds[i] = value.(string) + } + return auds + case []string: + return c.Audience.([]string) + default: + return []string{c.Audience.(string)} + } +} + +// VerifyAudience compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool { - return verifyAud(c.Audience, cmp, req) + audiences := ExtractAudience(c) + return verifyAud(audiences, cmp, req) } // Compares the exp claim against cmp. @@ -90,13 +109,15 @@ func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool { // ----- helpers -func verifyAud(aud string, cmp string, required bool) bool { - if aud == "" { +func verifyAud(auds []string, cmp string, required bool) bool { + if len(auds) == 0 { return !required - } - if subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 { - return true } else { + for _, aud := range auds { + if len(aud) == len(cmp) && subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 { + return true + } + } return false } } diff --git a/claims_test.go b/claims_test.go new file mode 100644 index 00000000..26f73ee7 --- /dev/null +++ b/claims_test.go @@ -0,0 +1,114 @@ +package jwt + +import ( + "testing" +) + +// Test StandardClaims instances with an audience value populated in a string, []string and []interface{} +var audienceValue = "Aud" +var unmatchedAudienceValue = audienceValue + "Test" +var claimWithAudience = []StandardClaims{ + { + audienceValue, + 123123, + "Id", + 12312, + "Issuer", + 12312, + "Subject", + }, + { + []string{audienceValue, unmatchedAudienceValue}, + 123123, + "Id", + 12312, + "Issuer", + 12312, + "Subject", + }, + { + []interface{}{audienceValue, unmatchedAudienceValue}, + 123123, + "Id", + 12312, + "Issuer", + 12312, + "Subject", + }, +} + +// Test StandardClaims instances with no audiences within empty []string and []interface{} collections. +var claimWithoutAudience = []StandardClaims{ + { + nil, + 123123, + "Id", + 12312, + "Issuer", + 12312, + "Subject", + }, + { + []string{}, + 123123, + "Id", + 12312, + "Issuer", + 12312, + "Subject", + }, + { + []interface{}{}, + 123123, + "Id", + 12312, + "Issuer", + 12312, + "Subject", + }, +} + +func TestExtractAudienceWithAudienceValues(t *testing.T) { + for _, data := range claimWithAudience { + var aud = ExtractAudience(&data) + if len(aud) == 0 || aud[0] != audienceValue { + t.Errorf("The audience value was not extracted properly") + } + } +} + +func TestExtractAudience_WithoutAudienceValues(t *testing.T) { + for _, data := range claimWithoutAudience { + var aud = ExtractAudience(&data) + if len(aud) != 0 { + t.Errorf("An audience value should not have been extracted") + } + } +} + +var audWithValues = [][]string{ + []string{audienceValue}, + []string{"Aud1", "Aud2", audienceValue}, +} + +var audWithLackingOriginalValue = [][]string{ + []string{}, + []string{audienceValue + "1"}, + []string{"Aud1", "Aud2", audienceValue + "1"}, +} + +func TestVerifyAud_ShouldVerifyExists(t *testing.T) { + for _, data := range audWithValues { + if !verifyAud(data, audienceValue, true) { + t.Errorf("The audience value was not verified properly") + } + } +} + +func TestVerifyAud_ShouldVerifyDoesNotExist(t *testing.T) { + for _, data := range audWithValues { + if !verifyAud(data, audienceValue, true) { + t.Errorf("The audience value was not verified properly") + } + } +} diff --git a/map_claims.go b/map_claims.go index 291213c4..643e18f3 100644 --- a/map_claims.go +++ b/map_claims.go @@ -13,8 +13,18 @@ type MapClaims map[string]interface{} // Compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyAudience(cmp string, req bool) bool { - aud, _ := m["aud"].(string) - return verifyAud(aud, cmp, req) + switch aud := m["aud"].(type) { + case []string: + return verifyAud(aud, cmp, req) + case []interface{}: + auds := make([]string, len(aud)) + for i, value := range aud { + auds[i] = value.(string) + } + return verifyAud(auds, cmp, req) + default: + return verifyAud([]string{aud.(string)}, cmp, req) + } } // Compares the exp claim against cmp. diff --git a/map_claims_tests.go b/map_claims_tests.go new file mode 100644 index 00000000..39ad4503 --- /dev/null +++ b/map_claims_tests.go @@ -0,0 +1,46 @@ +package jwt + +import ( + "testing" +) + +var audFixedValue = "Aud" +var audClaimsMapsWithValues = []MapClaims{ + { + "aud": audFixedValue, + }, + { + "aud": []string{audFixedValue}, + }, + { + "aud": []interface{}{audFixedValue}, + }, +} + +var audClaimsMapsWithoutValues = []MapClaims{ + {}, + { + "aud": []string{}, + }, + { + "aud": []interface{}{}, + }, +} + +// Verifies that for every form of the "aud" field, the audFixedValue is always verifiable +func TestVerifyAudienceWithVerifiableValues(t *testing.T) { + for _, data := range audClaimsMapsWithValues { + if !data.VerifyAudience(audFixedValue, true) { + t.Errorf("The audience value was not extracted properly") + } + } +} + +// Verifies that for every empty form of the "aud" field, the audFixedValue cannot be verified +func TestVerifyAudienceWithoutVerifiableValues(t *testing.T) { + for _, data := range audClaimsMapsWithoutValues { + if data.VerifyAudience(audFixedValue, true) { + t.Errorf("The audience should not verify") + } + } +}