diff --git a/pkg/doc/did/doc.go b/pkg/doc/did/doc.go index 1b7d5b46f..f41801483 100644 --- a/pkg/doc/did/doc.go +++ b/pkg/doc/did/doc.go @@ -166,9 +166,13 @@ func ParseDIDURL(didURL string) (*DIDURL, error) { return ret, nil } +// Context represents JSON-LD representation-specific DID-core @context, which +// must be either a string, or a list containing maps and/or strings. +type Context interface{} + // DocResolution did resolution. type DocResolution struct { - Context []string + Context Context DIDDocument *Doc DocumentMetadata *DocumentMetadata } @@ -224,7 +228,7 @@ type DocumentMetadata struct { } type rawDocResolution struct { - Context interface{} `json:"@context"` + Context Context `json:"@context"` DIDDocument json.RawMessage `json:"didDocument,omitempty"` DocumentMetadata json.RawMessage `json:"didDocumentMetadata,omitempty"` } @@ -261,7 +265,7 @@ func ParseDocumentResolution(data []byte) (*DocResolution, error) { // Doc DID Document definition. type Doc struct { - Context []string + Context Context ID string AlsoKnownAs []string VerificationMethod []VerificationMethod @@ -427,7 +431,7 @@ func NewReferencedVerification(vm *VerificationMethod, r VerificationRelationshi } type rawDoc struct { - Context interface{} `json:"@context,omitempty"` + Context Context `json:"@context,omitempty"` ID string `json:"id,omitempty"` AlsoKnownAs []interface{} `json:"alsoKnownAs,omitempty"` VerificationMethod []map[string]interface{} `json:"verificationMethod,omitempty"` @@ -507,7 +511,9 @@ func ParseDocument(data []byte) (*Doc, error) { verificationMethod = raw.VerificationMethod } - vm, err := populateVerificationMethod(context[0], doc.ID, baseURI, verificationMethod) + schema, _ := ContextPeekString(context) + + vm, err := populateVerificationMethod(schema, doc.ID, baseURI, verificationMethod) if err != nil { return nil, fmt.Errorf("populate verification method failed: %w", err) } @@ -519,7 +525,7 @@ func ParseDocument(data []byte) (*Doc, error) { return nil, err } - proofs, err := populateProofs(context[0], doc.ID, baseURI, raw.Proof) + proofs, err := populateProofs(schema, doc.ID, baseURI, raw.Proof) if err != nil { return nil, fmt.Errorf("populate proofs failed: %w", err) } @@ -530,18 +536,10 @@ func ParseDocument(data []byte) (*Doc, error) { } func requiresLegacyHandling(raw *rawDoc) bool { - context, _ := parseContext(raw.Context) - - for _, ctx := range context { - if ctx == ContextV1Old { - // aca-py issue: https://github.com/hyperledger/aries-cloudagent-python/issues/1048 - // old v1 context is (currently) only used by projects like aca-py that - // have not fully updated to latest did spec for aip2.0 - return true - } - } - - return false + // aca-py issue: https://github.com/hyperledger/aries-cloudagent-python/issues/1048 + // old v1 context is (currently) only used by projects like aca-py that + // have not fully updated to latest did spec for aip2.0 + return ContextContainsString(raw.Context, ContextV1Old) } func populateVerificationRelationships(doc *Doc, raw *rawDoc) error { @@ -758,7 +756,7 @@ func getVerification(doc *Doc, rawVerification interface{}, relationship VerificationRelationship) ([]Verification, error) { // context, docID string vm := doc.VerificationMethod - context := doc.Context[0] + context, _ := ContextPeekString(doc.Context) keyID, keyIDExist := rawVerification.(string) if keyIDExist { @@ -957,42 +955,45 @@ func decodeVMJwk(jwkMap map[string]interface{}, vm *VerificationMethod) error { return nil } -func parseContext(context interface{}) ([]string, string) { +func parseContext(context Context) (Context, string) { + context = ContextCopy(context) + switch ctx := context.(type) { + case string, []string: + return ctx, "" case []interface{}: - var context []string + // copy slice to prevent unexpected mutation + var newContext []interface{} var base string for _, v := range ctx { switch value := v.(type) { case string: - context = append(context, value) + newContext = append(newContext, value) case map[string]interface{}: - baseValue, ok := value["@base"].(string) - if ok { + // preserve base value if it exists and is a string + if baseValue, ok := value["@base"].(string); ok { base = baseValue } + + delete(value, "@base") + + if len(value) > 0 { + newContext = append(newContext, value) + } } } - return context, base - case []string: - return ctx, "" - case interface{}: - return []string{context.(string)}, "" + return ContextCleanup(newContext), base } - return []string{""}, "" + return "", "" } func (r *rawDoc) schemaLoader() gojsonschema.JSONLoader { - context, _ := parseContext(r.Context) - if len(context) == 0 { - return schemaLoaderV1 - } - - switch context[0] { + context, _ := ContextPeekString(r.Context) + switch context { case contextV011: return schemaLoaderV011 case contextV12019: @@ -1111,10 +1112,9 @@ func (docResolution *DocResolution) JSONBytes() ([]byte, error) { // JSONBytes converts document to json bytes. func (doc *Doc) JSONBytes() ([]byte, error) { - context := ContextV1 - - if len(doc.Context) > 0 { - context = doc.Context[0] + context, ok := ContextPeekString(doc.Context) + if !ok { + context = ContextV1 } aka := populateRawAlsoKnownAs(doc.AlsoKnownAs) @@ -1172,14 +1172,23 @@ func (doc *Doc) JSONBytes() ([]byte, error) { return byteDoc, nil } -func contextWithBase(doc *Doc) []interface{} { +func contextWithBase(doc *Doc) Context { baseObject := make(map[string]interface{}) baseObject["@base"] = doc.processingMeta.baseURI m := make([]interface{}, 0) - for _, v := range doc.Context { - m = append(m, v) + switch ctx := doc.Context.(type) { + case string: + m = append(m, ctx) + case []string: + for _, item := range ctx { + m = append(m, item) + } + case []interface{}: + if len(ctx) > 0 { + m = append(m, ctx...) + } } m = append(m, baseObject) diff --git a/pkg/doc/did/doc_test.go b/pkg/doc/did/doc_test.go index e73379033..0c5bfb2b8 100644 --- a/pkg/doc/did/doc_test.go +++ b/pkg/doc/did/doc_test.go @@ -84,7 +84,10 @@ func TestValidWithDocBase(t *testing.T) { doc, err := ParseDocument([]byte(d)) require.NoError(t, err) require.NotNil(t, doc) - require.Contains(t, doc.Context[0], "https://www.w3.org/ns/did/v") + + context, ok := doc.Context.([]string) + require.True(t, ok) + require.Contains(t, context[0], "https://www.w3.org/ns/did/v") // test doc id require.Equal(t, doc.ID, "did:example:123456789abcdefghi") @@ -176,8 +179,7 @@ func TestDocResolution(t *testing.T) { d, err := ParseDocumentResolution([]byte(validDocResolution)) require.NoError(t, err) - require.Equal(t, 1, len(d.Context)) - require.Equal(t, "https://w3id.org/did-resolution/v1", d.Context[0]) + require.Equal(t, "https://w3id.org/did-resolution/v1", d.Context.(string)) require.Equal(t, "did:example:21tDAKCERh95uGgKbJNHYp", d.DIDDocument.ID) require.Equal(t, 1, len(d.DIDDocument.AlsoKnownAs)) require.Equal(t, "did:example:123", d.DIDDocument.AlsoKnownAs[0]) @@ -190,8 +192,7 @@ func TestDocResolution(t *testing.T) { d, err = ParseDocumentResolution(bytes) require.NoError(t, err) - require.Equal(t, 1, len(d.Context)) - require.Equal(t, "https://w3id.org/did-resolution/v1", d.Context[0]) + require.Equal(t, "https://w3id.org/did-resolution/v1", d.Context.(string)) require.Equal(t, "did:example:21tDAKCERh95uGgKbJNHYp", d.DIDDocument.ID) require.Equal(t, 1, len(d.DIDDocument.AlsoKnownAs)) require.Equal(t, "did:example:123", d.DIDDocument.AlsoKnownAs[0]) @@ -206,13 +207,156 @@ func TestDocResolution(t *testing.T) { }) } +func TestContextVariations(t *testing.T) { + var ( + Base = "did:example:123456789abcdefghi" + Vocab = "https://www.w3.org/ns/did/#" + ContextDIDv1 = "https://www.w3.org/ns/did/v1" + ContextTraceability = "https://w3id.org/traceability/v1" + ContextBase = map[string]interface{}{"@base": Base} + ContextVocab = map[string]interface{}{"@vocab": Vocab} + ContextMixed = map[string]interface{}{"@base": Base, "@vocab": Vocab} + ) + + tests := map[string]struct { + input Context + context Context + base string + }{ + "'string'": { + input: ContextDIDv1, + context: ContextDIDv1, + base: "", + }, + "'string' empty": { + input: "", + context: "", + base: "", + }, + "'[]string' empty": { + input: []string{""}, + context: []string{""}, + base: "", + }, + "'[]string' single": { + input: []string{ContextDIDv1}, + context: []string{ContextDIDv1}, + base: "", + }, + "'[]string' multiple": { + input: []string{ContextDIDv1, ContextTraceability}, + context: []string{ContextDIDv1, ContextTraceability}, + base: "", + }, + "'[]interface{}' empty string": { + input: []interface{}{""}, + context: []string{""}, + base: "", + }, + "'[]interface{}' single string": { + input: []interface{}{ContextDIDv1}, + context: []string{ContextDIDv1}, + base: "", + }, + "'[]interface{}' multiple string": { + input: []interface{}{ContextDIDv1, ContextTraceability}, + context: []string{ContextDIDv1, ContextTraceability}, + base: "", + }, + "'[]interface{}' string + base": { + input: []interface{}{ContextDIDv1, ContextBase}, + context: []string{ContextDIDv1}, + base: Base, + }, + "'[]interface{}' string + vocab": { + input: []interface{}{ContextDIDv1, ContextVocab}, + context: []interface{}{ContextDIDv1, ContextVocab}, + base: "", + }, + "'[]interface{}' string + vocab + base": { + input: []interface{}{ContextDIDv1, ContextVocab, ContextBase}, + context: []interface{}{ContextDIDv1, ContextVocab}, + base: Base, + }, + "'[]interface{}' string + mixed": { + input: []interface{}{ContextDIDv1, ContextMixed}, + context: []interface{}{ContextDIDv1, ContextVocab}, + base: Base, + }, + "'[]interface{}' base": { + input: []interface{}{ContextBase}, + context: "", + base: Base, + }, + "'[]interface{}' vocab": { + input: []interface{}{ContextVocab}, + context: []interface{}{ContextVocab}, + base: "", + }, + "'[]interface{}' base + vocab": { + input: []interface{}{ContextBase, ContextVocab}, + context: []interface{}{ContextVocab}, + base: Base, + }, + "'[]interface{}' mixed": { + input: []interface{}{ContextMixed}, + context: []interface{}{ContextVocab}, + base: Base, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + context, base := parseContext(tc.input) + require.Equal(t, tc.context, context) + require.Equal(t, tc.base, base) + }) + } +} + +func TestContextMutationPrevention(t *testing.T) { + t.Run("string array mutation", func(t *testing.T) { + oldContext := []string{"stringval"} + newContext, _ := parseContext(oldContext) + + a0, ok := newContext.([]string) + require.True(t, ok) + + a0[0] = "mutated_stringval" + require.Equal(t, []string{"stringval"}, oldContext) + }) + + t.Run("map element mutation (@base)", func(t *testing.T) { + oldContext := []interface{}{map[string]interface{}{"@base": "baseval"}} + _, _ = parseContext(oldContext) + require.Equal(t, []interface{}{map[string]interface{}{"@base": "baseval"}}, oldContext) + }) + + t.Run("map element mutation (not @base)", func(t *testing.T) { + oldContext := []interface{}{map[string]interface{}{"@key": "keyval"}} + newContext, _ := parseContext(oldContext) + + a0, ok := newContext.([]interface{}) + require.True(t, ok) + + m0, ok := a0[0].(map[string]interface{}) + require.True(t, ok) + + m0["@key"] = "keyval_mutated" + require.Equal(t, []interface{}{map[string]interface{}{"@key": "keyval"}}, oldContext) + }) +} + func TestValid(t *testing.T) { docs := []string{validDoc} for _, d := range docs { doc, err := ParseDocument([]byte(d)) require.NoError(t, err) require.NotNil(t, doc) - require.Contains(t, doc.Context[0], "https://www.w3.org/ns/did/v") + + context, ok := doc.Context.([]string) + require.True(t, ok) + require.Contains(t, context[0], "https://www.w3.org/ns/did/v") // test doc id require.Equal(t, doc.ID, "did:example:21tDAKCERh95uGgKbJNHYp") diff --git a/pkg/doc/did/helpers.go b/pkg/doc/did/helpers.go index 2a157fa6f..66ad6adfa 100644 --- a/pkg/doc/did/helpers.go +++ b/pkg/doc/did/helpers.go @@ -5,6 +5,132 @@ SPDX-License-Identifier: Apache-2.0 package did +// ContextCleanup performs non-intrusive cleanup of the given context by +// converting `[]string(nil)` and `[]interface{}(nil)` to the empty string, and +// converting `[]interface{}` to `[]string` if it contains only string values. +// This will NOT change string arrays into single strings, even when they contain +// only a single string. +func ContextCleanup(context Context) Context { + context = ContextCopy(context) + + switch ctx := context.(type) { + case string: + return ctx + case []string: + if len(ctx) == 0 { + return []string{""} + } + + return ctx + case []interface{}: + if len(ctx) == 0 { + return "" + } + + var newContext []string + + for _, item := range ctx { + strVal, ok := item.(string) + if !ok { + return ctx + } + + newContext = append(newContext, strVal) + } + + return newContext + } + + return context +} + +// ContextCopy create a deep copy of the given context. This is used to prevent +// unintentional mutations of `Context` instances which are passed to functions +// that modify and return updated values, e.g., `parseContext()`. +func ContextCopy(context Context) Context { + switch ctx := context.(type) { + case string: + return ctx + case []string: + var newContext []string + newContext = append(newContext, ctx...) + + return newContext + case []interface{}: + var newContext []interface{} + + for _, v := range ctx { + switch value := v.(type) { + case string: + newContext = append(newContext, value) + case map[string]interface{}: + newValue := map[string]interface{}{} + for k, v := range value { + newValue[k] = v + } + + newContext = append(newContext, newValue) + } + } + + return newContext + } + + return context +} + +// ContextPeekString returns the first string element in `context`, which +// identifies the DID JSON-LD schema in use. This is generally useful to +// branch based on the version of the DID schema. +func ContextPeekString(context Context) (string, bool) { + switch ctx := context.(type) { + case string: + if len(ctx) > 0 { + return ctx, true + } + case []string: + if len(ctx) > 0 { + return ctx[0], true + } + case []interface{}: + if len(ctx) > 0 { + if strval, ok := ctx[0].(string); ok { + return strval, true + } + } + } + + return "", false +} + +// ContextContainsString returns true if the given Context contains the given +// context string. Strings nested inside maps are not checked. +func ContextContainsString(context Context, contextString string) bool { + // Extract all string values from context + var have []string + switch ctx := context.(type) { + case string: + have = append(have, ctx) + case []string: + have = append(have, ctx...) + case []interface{}: + for _, val := range ctx { + if strval, ok := val.(string); ok { + have = append(have, strval) + } + } + } + + // Look for desired string in extracted values + for _, item := range have { + if item == contextString { + return true + } + } + + return false +} + // LookupService returns the service from the given DIDDoc matching the given service type. func LookupService(didDoc *Doc, serviceType string) (*Service, bool) { const notFound = -1 diff --git a/pkg/doc/did/helpers_test.go b/pkg/doc/did/helpers_test.go index 365788840..3c4307ac5 100644 --- a/pkg/doc/did/helpers_test.go +++ b/pkg/doc/did/helpers_test.go @@ -6,6 +6,7 @@ SPDX-License-Identifier: Apache-2.0 package did_test import ( + "reflect" "testing" "github.com/stretchr/testify/require" @@ -14,6 +15,148 @@ import ( mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/mock/diddoc" ) +func TestContextCleanup(t *testing.T) { + t.Run("string", func(t *testing.T) { + var c0 Context = "stringval" + var c1 Context = ContextCleanup(c0) + require.Equal(t, c0, c1) + }) + + t.Run("[]string", func(t *testing.T) { + var c0 Context = []string{"stringval"} + var c1 Context = ContextCleanup(c0) + require.Equal(t, c0, c1) + }) + + t.Run("[]string empty", func(t *testing.T) { + var c0 Context = []string{""} + var c1 Context = ContextCleanup(c0) + require.Equal(t, c0, c1) + }) + + t.Run("[]string nil", func(t *testing.T) { + var c0 Context = []string{} + var c1 Context = ContextCleanup(c0) + require.Equal(t, []string{""}, c1) + }) + + t.Run("[]interface{} nil value", func(t *testing.T) { + var c0 Context = []interface{}{} + var c1 Context = ContextCleanup(c0) + require.Equal(t, "", c1) + }) + + t.Run("[]interface{} all strings", func(t *testing.T) { + var c0 Context = []interface{}{"alpha", "beta"} + var c1 Context = ContextCleanup(c0) + require.Equal(t, []string{"alpha", "beta"}, c1) + }) + + t.Run("[]interface{} with map", func(t *testing.T) { + var c0 Context = []interface{}{map[string]interface{}{"@key": "value"}} + var c1 Context = ContextCleanup(c0) + require.Equal(t, c0, c1) + }) +} + +func TestContextCopy(t *testing.T) { + t.Run("string", func(t *testing.T) { + var c0 Context = "stringval" + var c1 Context = ContextCopy(c0) + require.Equal(t, c0, c1) + }) + + t.Run("[]string", func(t *testing.T) { + var c0 Context = []string{"stringval"} + var c1 Context = ContextCopy(c0) + require.Equal(t, c0, c1) + + p0, p1 := reflect.ValueOf(c0).Pointer(), reflect.ValueOf(c1).Pointer() + require.NotEqual(t, p0, p1, "slices should not share pointers") + }) + + t.Run("[]interface{}", func(t *testing.T) { + var c0 Context = []interface{}{map[string]interface{}{"@key": "value"}} + var c1 Context = ContextCopy(c0) + require.Equal(t, c0, c1) + + p0, p1 := reflect.ValueOf(c0).Pointer(), reflect.ValueOf(c1).Pointer() + require.NotEqual(t, p0, p1, "slices should not share pointers") + + a0, ok := c0.([]interface{}) + require.True(t, ok) + + a1, ok := c1.([]interface{}) + require.True(t, ok) + + p0, p1 = reflect.ValueOf(a0[0]).Pointer(), reflect.ValueOf(a1[0]).Pointer() + require.NotEqual(t, p0, p1, "maps should not share pointers") + }) +} + +func TestContextPeekString(t *testing.T) { + const ( + DoNotWant = "ContextDoNotWant" + Want = "ContextWant" + ) + + tests := map[string]struct { + schema string + ok bool + input Context + }{ + "present in 'string'": {schema: Want, ok: true, input: Want}, + "present in '[]string' (single)": {schema: Want, ok: true, input: []string{Want}}, + "present in '[]string' (multiple)": {schema: Want, ok: true, input: []string{Want, DoNotWant}}, + "present in '[]interface{}' (single)": {schema: Want, ok: true, input: []interface{}{Want}}, + "present in '[]interface{}' (multiple)": {schema: Want, ok: true, input: []interface{}{Want, DoNotWant}}, + "not present in 'string'": {schema: "", ok: false, input: ""}, + "not present in '[]string'": {schema: "", ok: false, input: []string{}}, + "not present in '[]interface{}'": {schema: "", ok: false, input: []interface{}{}}, + "context is nil": {schema: "", ok: false, input: nil}, + "context is invalid": {schema: "", ok: false, input: 42}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + schema, ok := ContextPeekString(tc.input) + require.Equal(t, tc.ok, ok) + require.Equal(t, tc.schema, schema) + }) + } +} + +func TestContextContainsString(t *testing.T) { + const ( + DoNotWant = "ContextDoNotWant" + Want = "ContextWant" + ) + + tests := map[string]struct { + ok bool + input Context + }{ + "present in 'string'": {ok: true, input: Want}, + "present in '[]string' (first)": {ok: true, input: []string{Want}}, + "present in '[]string' (not first)": {ok: true, input: []string{DoNotWant, Want}}, + "present in '[]interface{}' (first)": {ok: true, input: []interface{}{Want}}, + "present in '[]interface{}' (not first)": {ok: true, input: []interface{}{DoNotWant, Want}}, + "present in '[]interface{}' (map)": {ok: false, input: []interface{}{map[string]interface{}{"k": Want}}}, + "not present in 'string'": {ok: false, input: DoNotWant}, + "not present in '[]string'": {ok: false, input: []string{DoNotWant}}, + "not present in '[]interface{}'": {ok: false, input: []interface{}{DoNotWant}}, + "context is nil": {ok: false, input: nil}, + "context is invalid": {ok: false, input: 42}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ok := ContextContainsString(tc.input, Want) + require.Equal(t, tc.ok, ok) + }) + } +} + func TestGetRecipientKeys(t *testing.T) { t.Run("successfully getting recipient keys", func(t *testing.T) { didDoc := mockdiddoc.GetMockDIDDoc(t, false) diff --git a/pkg/vdr/key/creator_test.go b/pkg/vdr/key/creator_test.go index cf597ba2f..8215c9fc4 100644 --- a/pkg/vdr/key/creator_test.go +++ b/pkg/vdr/key/creator_test.go @@ -258,8 +258,20 @@ func assertBase58Doc(t *testing.T, doc *did.Doc, didKey, didKeyID, didKeyType, p func assertDualBase58Doc(t *testing.T, doc *did.Doc, didKey, didKeyID, didKeyType, pubKeyBase58, agreementKeyID, keyAgreementType, keyAgreementBase58 string) { + var context string + switch ctx := doc.Context.(type) { + case string: + context = ctx + case []string: + context = ctx[0] + case []interface{}: + var ok bool + context, ok = ctx[0].(string) + require.True(t, ok) + } + // validate @context - require.Equal(t, schemaDIDV1, doc.Context[0]) + require.Equal(t, schemaDIDV1, context) // validate id require.Equal(t, didKey, doc.ID) @@ -332,8 +344,20 @@ func createVerificationMethodFromXAndY(t *testing.T, didKeyID, didKey string, func assertDualJSONWebKeyDoc(t *testing.T, doc *did.Doc, didKey, didKeyID string, pubKeyCurve elliptic.Curve, pubKeyX, pubKeyY *big.Int, agreementKeyID string, keyAgreementCurve elliptic.Curve, keyAgreementX, keyAgreementY *big.Int) { + var context string + switch ctx := doc.Context.(type) { + case string: + context = ctx + case []string: + context = ctx[0] + case []interface{}: + var ok bool + context, ok = ctx[0].(string) + require.True(t, ok) + } + // validate @context - require.Equal(t, schemaDIDV1, doc.Context[0]) + require.Equal(t, schemaDIDV1, context) // validate id require.Equal(t, didKey, doc.ID)