Skip to content
This repository has been archived by the owner on Mar 27, 2024. It is now read-only.

feat: support multi-format VC in presentation exchange #3347

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 127 additions & 25 deletions pkg/doc/presexch/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/xeipuuv/gojsonschema"

"github.com/hyperledger/aries-framework-go/pkg/common/log"
"github.com/hyperledger/aries-framework-go/pkg/doc/jwt"
"github.com/hyperledger/aries-framework-go/pkg/doc/verifiable"
)

Expand All @@ -41,6 +42,19 @@ const (
tmpEnding = "tmp_unique_id_"

credentialSchema = "credentialSchema"

// FormatJWT presentation exchange format.
FormatJWT = "jwt"
// FormatJWTVC presentation exchange format.
FormatJWTVC = "jwt_vc"
// FormatJWTVP presentation exchange format.
FormatJWTVP = "jwt_vp"
// FormatLDP presentation exchange format.
FormatLDP = "ldp"
// FormatLDPVC presentation exchange format.
FormatLDPVC = "ldp_vc"
// FormatLDPVP presentation exchange format.
FormatLDPVP = "ldp_vp"
)

var errPathNotApplicable = errors.New("path not applicable")
Expand Down Expand Up @@ -327,12 +341,12 @@ func (pd *PresentationDefinition) CreateVP(credentials []*verifiable.Credential,
return nil, err
}

result, err := pd.applyRequirement(req, credentials, documentLoader, opts...)
format, result, err := pd.applyRequirement(req, credentials, documentLoader, opts...)
if err != nil {
return nil, err
}

applicableCredentials, descriptors := merge(result)
applicableCredentials, descriptors := merge(format, result)

vp, err := verifiable.NewPresentation(verifiable.WithCredentials(applicableCredentials...))
if err != nil {
Expand All @@ -358,8 +372,13 @@ var ErrNoCredentials = errors.New("credentials do not satisfy requirements")

// nolint: gocyclo,funlen,gocognit
func (pd *PresentationDefinition) applyRequirement(req *requirement, creds []*verifiable.Credential,
documentLoader ld.DocumentLoader, opts ...verifiable.CredentialOpt) (map[string][]*verifiable.Credential, error) {
documentLoader ld.DocumentLoader,
opts ...verifiable.CredentialOpt) (string, map[string][]*verifiable.Credential, error) {
result := make(map[string][]*verifiable.Credential)
// assume LDPVP format if pd.Format is not set.
// Usually pd.Format will be set when creds include a non-empty Proofs field since they represent the designated
// format.
vpFormat := FormatLDPVP

for _, descriptor := range req.InputDescriptors {
format := pd.Format
Expand All @@ -371,11 +390,11 @@ func (pd *PresentationDefinition) applyRequirement(req *requirement, creds []*ve

filtered, err := frameCreds(pd.Frame, filtered, opts...)
if err != nil {
return nil, err
return "", nil, err
}

if format != nil {
filtered = filterFormat(format, filtered)
vpFormat, filtered = filterFormat(format, filtered)
}

// Validate schema only for v1
Expand All @@ -385,7 +404,7 @@ func (pd *PresentationDefinition) applyRequirement(req *requirement, creds []*ve

filtered, err = filterConstraints(descriptor.Constraints, filtered, opts...)
if err != nil {
return nil, err
return "", nil, err
}

if len(filtered) != 0 {
Expand All @@ -395,10 +414,10 @@ func (pd *PresentationDefinition) applyRequirement(req *requirement, creds []*ve

if len(req.InputDescriptors) != 0 {
if req.isLenApplicable(len(result)) {
return result, nil
return vpFormat, result, nil
}

return nil, ErrNoCredentials
return "", nil, ErrNoCredentials
}

var nestedResult []map[string][]*verifiable.Credential
Expand All @@ -407,13 +426,13 @@ func (pd *PresentationDefinition) applyRequirement(req *requirement, creds []*ve
set := map[string]map[string]string{}

for _, r := range req.Nested {
res, err := pd.applyRequirement(r, creds, documentLoader, opts...)
vpFmt, res, err := pd.applyRequirement(r, creds, documentLoader, opts...)
if errors.Is(err, ErrNoCredentials) {
continue
}

if err != nil {
return nil, err
return "", nil, err
}

for desc, credentials := range res {
Expand All @@ -428,6 +447,7 @@ func (pd *PresentationDefinition) applyRequirement(req *requirement, creds []*ve

if len(res) != 0 {
nestedResult = append(nestedResult, res)
vpFormat = vpFmt
}
}

Expand All @@ -441,7 +461,7 @@ func (pd *PresentationDefinition) applyRequirement(req *requirement, creds []*ve
}
}

return mergeNestedResult(nestedResult, exclude), nil
return vpFormat, mergeNestedResult(nestedResult, exclude), nil
}

func mergeNestedResult(nr []map[string][]*verifiable.Credential,
Expand Down Expand Up @@ -914,7 +934,7 @@ func getPath(keys []interface{}, set map[string]int) [2]string {
return [...]string{strings.Join(newPath, "."), strings.Join(originalPath, ".")}
}

func merge(setOfCredentials map[string][]*verifiable.Credential) ([]*verifiable.Credential, []*InputDescriptorMapping) {
func merge(format string, setOfCredentials map[string][]*verifiable.Credential) ([]*verifiable.Credential, []*InputDescriptorMapping) { //nolint:lll
setOfCreds := make(map[string]int)
setOfDescriptors := make(map[string]struct{})

Expand Down Expand Up @@ -942,9 +962,8 @@ func merge(setOfCredentials map[string][]*verifiable.Credential) ([]*verifiable.

if _, ok := setOfDescriptors[fmt.Sprintf("%s-%s", credential.ID, credential.ID)]; !ok {
descriptors = append(descriptors, &InputDescriptorMapping{
ID: descriptorID,
// TODO: what format should be here?
Format: "ldp_vp",
ID: descriptorID,
Format: format,
Path: fmt.Sprintf("$.verifiableCredential[%d]", setOfCreds[credential.ID]),
})
}
Expand All @@ -962,22 +981,105 @@ func (a byID) Len() int { return len(a) }
func (a byID) Less(i, j int) bool { return a[i].ID < a[j].ID }
func (a byID) Swap(i, j int) { a[i], a[j] = a[j], a[i] }

func filterFormat(format *Format, credentials []*verifiable.Credential) []*verifiable.Credential {
var result []*verifiable.Credential

if format.LdpVP == nil {
return result
}
//nolint:funlen,gocyclo
func filterFormat(format *Format, credentials []*verifiable.Credential) (string, []*verifiable.Credential) {
var ldpCreds, ldpvcCreds, ldpvpCreds, jwtCreds, jwtvcCreds, jwtvpCreds []*verifiable.Credential

for _, credential := range credentials {
for _, proofType := range format.LdpVP.ProofType {
if hasProofWithType(credential, proofType) {
result = append(result, credential)
if credByProof(credential, format.Ldp) {
ldpCreds = append(ldpCreds, credential)
}

if credByProof(credential, format.LdpVC) {
ldpvcCreds = append(ldpvcCreds, credential)
}

if credByProof(credential, format.LdpVP) {
ldpvpCreds = append(ldpvpCreds, credential)
}

var (
alg string
hasAlg bool
)

if credential.JWT != "" {
pJWT, err := jwt.Parse(credential.JWT)
if err != nil {
logger.Warnf("unmarshal credential error: %w", err)

continue
}

alg, hasAlg = pJWT.Headers.Algorithm()
}

if hasAlg && algMatch(alg, format.Jwt) {
jwtCreds = append(jwtCreds, credential)
}

if hasAlg && algMatch(alg, format.JwtVC) {
jwtvcCreds = append(jwtvcCreds, credential)
}

if hasAlg && algMatch(alg, format.JwtVP) {
jwtvpCreds = append(jwtvpCreds, credential)
}
}

return result
if len(ldpCreds) > 0 {
return FormatLDP, ldpCreds
}

if len(ldpvcCreds) > 0 {
return FormatLDPVC, ldpvcCreds
}

if len(ldpvpCreds) > 0 {
return FormatLDPVP, ldpvpCreds
}

if len(jwtCreds) > 0 {
return FormatJWT, jwtCreds
}

if len(jwtvcCreds) > 0 {
return FormatJWTVC, jwtvcCreds
}

if len(jwtvpCreds) > 0 {
return FormatJWTVP, jwtvpCreds
}

return "", nil
}

func algMatch(credAlg string, jwtType *JwtType) bool {
if jwtType == nil {
return false
}

for _, b := range jwtType.Alg {
if strings.EqualFold(credAlg, b) {
return true
}
}

return false
}

func credByProof(c *verifiable.Credential, ldp *LdpType) bool {
if ldp == nil {
return false
}

for _, proofType := range ldp.ProofType {
if hasProofWithType(c, proofType) {
return true
}
}

return false
}

// nolint: gocyclo
Expand Down
Loading