Skip to content
This repository has been archived by the owner on Mar 27, 2024. It is now read-only.

Commit

Permalink
feat: allow mediator GetConnections APIs to filter by didcomm version.
Browse files Browse the repository at this point in the history
Signed-off-by: Filip Burlacu <filip.burlacu@securekey.com>
  • Loading branch information
Filip Burlacu committed Aug 9, 2022
1 parent adc27a6 commit d8ea6f7
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 26 deletions.
6 changes: 3 additions & 3 deletions pkg/client/mediator/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type protocolService interface {
Unregister(connID string) error

// GetConnections returns router`s connections.
GetConnections() ([]string, error)
GetConnections(...mediator.ConnectionOption) ([]string, error)

// Config returns the router's configuration.
Config(connID string) (*mediator.Config, error)
Expand Down Expand Up @@ -91,8 +91,8 @@ func (c *Client) Unregister(connID string) error {
}

// GetConnections returns router`s connections.
func (c *Client) GetConnections() ([]string, error) {
connections, err := c.routeSvc.GetConnections()
func (c *Client) GetConnections(options ...ConnectionOption) ([]string, error) {
connections, err := c.routeSvc.GetConnections(options...)
if err != nil {
return nil, fmt.Errorf("get router connections: %w", err)
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/client/mediator/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ const (
// Request is the route-request message of this protocol.
type Request = mediator.Request

// ConnectionOption option for Client.GetConnections.
type ConnectionOption = mediator.ConnectionOption

// NewRequest creates a new request.
func NewRequest() *Request {
return &Request{
Expand Down
36 changes: 35 additions & 1 deletion pkg/controller/command/mediator/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/hyperledger/aries-framework-go/pkg/controller/command"
"github.com/hyperledger/aries-framework-go/pkg/controller/internal/cmdutil"
"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
mediatorSvc "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator"
"github.com/hyperledger/aries-framework-go/pkg/internal/logutil"
"github.com/hyperledger/aries-framework-go/pkg/kms"
)
Expand Down Expand Up @@ -215,7 +216,40 @@ func (o *Command) Unregister(rw io.Writer, req io.Reader) command.Error {

// Connections returns the connections of the router.
func (o *Command) Connections(rw io.Writer, req io.Reader) command.Error {
connections, err := o.routeClient.GetConnections()
var request ConnectionsRequest

if req != nil {
reqData, err := io.ReadAll(req)
if err != nil {
logutil.LogInfo(logger, CommandName, GetConnectionsCommandMethod, err.Error())
return command.NewValidationError(GetConnectionsErrorCode, fmt.Errorf("read request : %w", err))
}

if len(reqData) > 0 {
err = json.Unmarshal(reqData, &request)
if err != nil {
logutil.LogInfo(logger, CommandName, GetConnectionsCommandMethod, err.Error())
return command.NewValidationError(InvalidRequestErrorCode, fmt.Errorf("decode request : %w", err))
}
}
}

opts := []mediator.ConnectionOption{}

switch {
case request.DIDCommV1Only && request.DIDCommV2Only:
errMsg := "can't request didcomm v1 only at the same time as didcomm v2 only"

logutil.LogError(logger, CommandName, GetConnectionsCommandMethod, errMsg)

return command.NewValidationError(GetConnectionsErrorCode, fmt.Errorf("%s", errMsg))
case request.DIDCommV2Only:
opts = append(opts, mediatorSvc.ConnectionByVersion(service.V2))
case request.DIDCommV1Only:
opts = append(opts, mediatorSvc.ConnectionByVersion(service.V1))
}

connections, err := o.routeClient.GetConnections(opts...)
if err != nil {
logutil.LogError(logger, CommandName, GetConnectionsCommandMethod, err.Error())
return command.NewExecuteError(GetConnectionsErrorCode, err)
Expand Down
100 changes: 95 additions & 5 deletions pkg/controller/command/mediator/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,96 @@ func TestCommand_Connections(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cmd)

testcases := []struct {
name string
input string
}{
{
name: "no filters",
input: `{}`,
},
{
name: "didcomm v1 only",
input: `{"didcomm_v1": true}`,
},
{
name: "didcomm v2 only",
input: `{"didcomm_v2": true}`,
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
var b bytes.Buffer
err = cmd.Connections(&b, bytes.NewBufferString(tc.input))
require.NoError(t, err)

response := ConnectionsResponse{}
err = json.NewDecoder(&b).Decode(&response)
require.NoError(t, err)
require.Equal(t, routerConnectionID, response.Connections[0])
})
}
})

t.Run("test get connection - read request error", func(t *testing.T) {
cmd, err := New(
&mockprovider.Provider{
ServiceMap: map[string]interface{}{
messagepickupSvc.MessagePickup: &messagepickup.MockMessagePickupSvc{},
mediator.Coordination: &mockroute.MockMediatorSvc{},
oobsvc.Name: &mockoob.MockOobService{},
},
},
false,
)
require.NoError(t, err)
require.NotNil(t, cmd)

var b bytes.Buffer
err = cmd.Connections(&b, nil)
err = cmd.Connections(&b, &errReader{err: fmt.Errorf("expected error")})
require.Error(t, err)
require.Contains(t, err.Error(), "read request")
})

t.Run("test get connection - decode request error", func(t *testing.T) {
cmd, err := New(
&mockprovider.Provider{
ServiceMap: map[string]interface{}{
messagepickupSvc.MessagePickup: &messagepickup.MockMessagePickupSvc{},
mediator.Coordination: &mockroute.MockMediatorSvc{},
oobsvc.Name: &mockoob.MockOobService{},
},
},
false,
)
require.NoError(t, err)
require.NotNil(t, cmd)

response := ConnectionsResponse{}
err = json.NewDecoder(&b).Decode(&response)
var b bytes.Buffer
err = cmd.Connections(&b, bytes.NewBufferString("{"))
require.Error(t, err)
require.Contains(t, err.Error(), "decode request")
})

t.Run("test get connection - invalid filter options error", func(t *testing.T) {
cmd, err := New(
&mockprovider.Provider{
ServiceMap: map[string]interface{}{
messagepickupSvc.MessagePickup: &messagepickup.MockMessagePickupSvc{},
mediator.Coordination: &mockroute.MockMediatorSvc{},
oobsvc.Name: &mockoob.MockOobService{},
},
},
false,
)
require.NoError(t, err)
require.Equal(t, routerConnectionID, response.Connections[0])
require.NotNil(t, cmd)

var b bytes.Buffer
err = cmd.Connections(&b, bytes.NewBufferString(`{"didcomm_v1": true, "didcomm_v2": true}`))
require.Error(t, err)
require.Contains(t, err.Error(), "at the same time")
})

t.Run("test get connection - error", func(t *testing.T) {
Expand All @@ -234,7 +316,7 @@ func TestCommand_Connections(t *testing.T) {
require.NotNil(t, cmd)

var b bytes.Buffer
err = cmd.Connections(&b, nil)
err = cmd.Connections(&b, bytes.NewBufferString("{}"))
require.Error(t, err)
require.Contains(t, err.Error(), "get router connections")
})
Expand Down Expand Up @@ -533,3 +615,11 @@ func newMockProvider(serviceMap map[string]interface{}) *mockprovider.Provider {
ProtocolStateStorageProviderValue: mockstorage.NewMockStoreProvider(),
}
}

type errReader struct {
err error
}

func (e *errReader) Read([]byte) (int, error) {
return 0, e.err
}
6 changes: 6 additions & 0 deletions pkg/controller/command/mediator/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ type RegisterRoute struct {
ConnectionID string `json:"connectionID"`
}

// ConnectionsRequest contains parameters for filtering when requesting router connections.
type ConnectionsRequest struct {
DIDCommV1Only bool `json:"didcomm_v1"`
DIDCommV2Only bool `json:"didcomm_v2"`
}

// ConnectionsResponse is response for router`s connections.
type ConnectionsResponse struct {
Connections []string `json:"connections"`
Expand Down
2 changes: 1 addition & 1 deletion pkg/didcomm/protocol/mediator/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ type ProtocolService interface {
Config(connID string) (*Config, error)

// GetConnections returns all router connections
GetConnections() ([]string, error)
GetConnections(options ...ConnectionOption) ([]string, error)
}
54 changes: 49 additions & 5 deletions pkg/didcomm/protocol/mediator/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ type callback struct {
err error
}

type routerConnectionEntry struct {
ConnectionID string `json:"connectionID"`
DIDCommVersion service.Version `json:"didcomm_version,omitempty"`
}

type connections interface {
GetConnectionIDByDIDs(string, string) (string, error)
GetConnectionRecord(string) (*connection.Record, error)
Expand Down Expand Up @@ -646,7 +651,7 @@ func (s *Service) doRegistration(record *connection.Record, req *Request, timeou
logger.Debugf("saved router config from inbound grant: %+v", grant)

// save the connectionID of the router
return s.saveRouterConnectionID(record.ConnectionID)
return s.saveRouterConnectionID(record.ConnectionID, record.DIDCommVersion)
}

func (s *Service) getGrant(id string, timeout time.Duration) (*Grant, error) {
Expand Down Expand Up @@ -700,7 +705,13 @@ func (s *Service) Unregister(connID string) error {
}

// GetConnections returns the connections of the router.
func (s *Service) GetConnections() ([]string, error) {
func (s *Service) GetConnections(options ...ConnectionOption) ([]string, error) {
opts := &getConnectionOpts{}

for _, option := range options {
option(opts)
}

records, err := s.routeStore.Query(routeConnIDDataKey)
if err != nil {
return nil, fmt.Errorf("failed to query route store: %w", err)
Expand All @@ -721,7 +732,16 @@ func (s *Service) GetConnections() ([]string, error) {
return nil, fmt.Errorf("failed to get value from records: %w", err)
}

conns = append(conns, string(value))
data := &routerConnectionEntry{}

err = json.Unmarshal(value, data)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal router connection entry: %w", err)
}

if opts.version == "" || opts.version == data.DIDCommVersion {
conns = append(conns, data.ConnectionID)
}

more, err = records.Next()
if err != nil {
Expand Down Expand Up @@ -838,8 +858,18 @@ func (s *Service) deleteRouterConnectionID(connID string) error {
return s.routeStore.Delete(fmt.Sprintf(routeConnIDDataKey, connID))
}

func (s *Service) saveRouterConnectionID(connID string) error {
return s.routeStore.Put(fmt.Sprintf(routeConnIDDataKey, connID), []byte(connID), storage.Tag{Name: routeConnIDDataKey})
func (s *Service) saveRouterConnectionID(connID string, didcommVersion service.Version) error {
data := &routerConnectionEntry{
ConnectionID: connID,
DIDCommVersion: didcommVersion,
}

dataBytes, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshalling router connection ID data: %w", err)
}

return s.routeStore.Put(fmt.Sprintf(routeConnIDDataKey, connID), dataBytes, storage.Tag{Name: routeConnIDDataKey})
}

type config struct {
Expand Down Expand Up @@ -918,3 +948,17 @@ func parseClientOpts(options ...ClientOption) *ClientOptions {

return opts
}

type getConnectionOpts struct {
version service.Version
}

// ConnectionOption option for Service.GetConnections.
type ConnectionOption func(opts *getConnectionOpts)

// ConnectionByVersion filter for mediator connections of the given DIDComm version.
func ConnectionByVersion(v service.Version) ConnectionOption {
return func(opts *getConnectionOpts) {
opts.version = v
}
}
Loading

0 comments on commit d8ea6f7

Please sign in to comment.