diff --git a/pkg/client/mediator/client.go b/pkg/client/mediator/client.go index f17d574a2..4f3eec55b 100644 --- a/pkg/client/mediator/client.go +++ b/pkg/client/mediator/client.go @@ -39,7 +39,7 @@ type protocolService interface { Unregister(connID string) error // GetConnections returns router`s connections. - GetConnections() ([]string, error) + GetConnections(...mediator.ConnectionOption) ([]string, error) // Config returns the router's configuration. Config(connID string) (*mediator.Config, error) @@ -91,8 +91,8 @@ func (c *Client) Unregister(connID string) error { } // GetConnections returns router`s connections. -func (c *Client) GetConnections() ([]string, error) { - connections, err := c.routeSvc.GetConnections() +func (c *Client) GetConnections(options ...ConnectionOption) ([]string, error) { + connections, err := c.routeSvc.GetConnections(options...) if err != nil { return nil, fmt.Errorf("get router connections: %w", err) } diff --git a/pkg/client/mediator/models.go b/pkg/client/mediator/models.go index b6b8dbdc2..adf23ecf1 100644 --- a/pkg/client/mediator/models.go +++ b/pkg/client/mediator/models.go @@ -22,6 +22,9 @@ const ( // Request is the route-request message of this protocol. type Request = mediator.Request +// ConnectionOption option for Client.GetConnections. +type ConnectionOption = mediator.ConnectionOption + // NewRequest creates a new request. func NewRequest() *Request { return &Request{ diff --git a/pkg/controller/command/mediator/command.go b/pkg/controller/command/mediator/command.go index cf69599df..ff29e8ce0 100644 --- a/pkg/controller/command/mediator/command.go +++ b/pkg/controller/command/mediator/command.go @@ -19,6 +19,7 @@ import ( "github.com/hyperledger/aries-framework-go/pkg/controller/command" "github.com/hyperledger/aries-framework-go/pkg/controller/internal/cmdutil" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" + mediatorSvc "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator" "github.com/hyperledger/aries-framework-go/pkg/internal/logutil" "github.com/hyperledger/aries-framework-go/pkg/kms" ) @@ -215,7 +216,40 @@ func (o *Command) Unregister(rw io.Writer, req io.Reader) command.Error { // Connections returns the connections of the router. func (o *Command) Connections(rw io.Writer, req io.Reader) command.Error { - connections, err := o.routeClient.GetConnections() + var request ConnectionsRequest + + if req != nil { + reqData, err := io.ReadAll(req) + if err != nil { + logutil.LogInfo(logger, CommandName, GetConnectionsCommandMethod, err.Error()) + return command.NewValidationError(GetConnectionsErrorCode, fmt.Errorf("read request : %w", err)) + } + + if len(reqData) > 0 { + err = json.Unmarshal(reqData, &request) + if err != nil { + logutil.LogInfo(logger, CommandName, GetConnectionsCommandMethod, err.Error()) + return command.NewValidationError(InvalidRequestErrorCode, fmt.Errorf("decode request : %w", err)) + } + } + } + + opts := []mediator.ConnectionOption{} + + switch { + case request.DIDCommV1Only && request.DIDCommV2Only: + errMsg := "can't request didcomm v1 only at the same time as didcomm v2 only" + + logutil.LogError(logger, CommandName, GetConnectionsCommandMethod, errMsg) + + return command.NewValidationError(GetConnectionsErrorCode, fmt.Errorf("%s", errMsg)) + case request.DIDCommV2Only: + opts = append(opts, mediatorSvc.ConnectionByVersion(service.V2)) + case request.DIDCommV1Only: + opts = append(opts, mediatorSvc.ConnectionByVersion(service.V1)) + } + + connections, err := o.routeClient.GetConnections(opts...) if err != nil { logutil.LogError(logger, CommandName, GetConnectionsCommandMethod, err.Error()) return command.NewExecuteError(GetConnectionsErrorCode, err) diff --git a/pkg/controller/command/mediator/command_test.go b/pkg/controller/command/mediator/command_test.go index f6cfa1575..b45cfc325 100644 --- a/pkg/controller/command/mediator/command_test.go +++ b/pkg/controller/command/mediator/command_test.go @@ -207,14 +207,96 @@ func TestCommand_Connections(t *testing.T) { require.NoError(t, err) require.NotNil(t, cmd) + testcases := []struct { + name string + input string + }{ + { + name: "no filters", + input: `{}`, + }, + { + name: "didcomm v1 only", + input: `{"didcomm_v1": true}`, + }, + { + name: "didcomm v2 only", + input: `{"didcomm_v2": true}`, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + var b bytes.Buffer + err = cmd.Connections(&b, bytes.NewBufferString(tc.input)) + require.NoError(t, err) + + response := ConnectionsResponse{} + err = json.NewDecoder(&b).Decode(&response) + require.NoError(t, err) + require.Equal(t, routerConnectionID, response.Connections[0]) + }) + } + }) + + t.Run("test get connection - read request error", func(t *testing.T) { + cmd, err := New( + &mockprovider.Provider{ + ServiceMap: map[string]interface{}{ + messagepickupSvc.MessagePickup: &messagepickup.MockMessagePickupSvc{}, + mediator.Coordination: &mockroute.MockMediatorSvc{}, + oobsvc.Name: &mockoob.MockOobService{}, + }, + }, + false, + ) + require.NoError(t, err) + require.NotNil(t, cmd) + var b bytes.Buffer - err = cmd.Connections(&b, nil) + err = cmd.Connections(&b, &errReader{err: fmt.Errorf("expected error")}) + require.Error(t, err) + require.Contains(t, err.Error(), "read request") + }) + + t.Run("test get connection - decode request error", func(t *testing.T) { + cmd, err := New( + &mockprovider.Provider{ + ServiceMap: map[string]interface{}{ + messagepickupSvc.MessagePickup: &messagepickup.MockMessagePickupSvc{}, + mediator.Coordination: &mockroute.MockMediatorSvc{}, + oobsvc.Name: &mockoob.MockOobService{}, + }, + }, + false, + ) require.NoError(t, err) + require.NotNil(t, cmd) - response := ConnectionsResponse{} - err = json.NewDecoder(&b).Decode(&response) + var b bytes.Buffer + err = cmd.Connections(&b, bytes.NewBufferString("{")) + require.Error(t, err) + require.Contains(t, err.Error(), "decode request") + }) + + t.Run("test get connection - invalid filter options error", func(t *testing.T) { + cmd, err := New( + &mockprovider.Provider{ + ServiceMap: map[string]interface{}{ + messagepickupSvc.MessagePickup: &messagepickup.MockMessagePickupSvc{}, + mediator.Coordination: &mockroute.MockMediatorSvc{}, + oobsvc.Name: &mockoob.MockOobService{}, + }, + }, + false, + ) require.NoError(t, err) - require.Equal(t, routerConnectionID, response.Connections[0]) + require.NotNil(t, cmd) + + var b bytes.Buffer + err = cmd.Connections(&b, bytes.NewBufferString(`{"didcomm_v1": true, "didcomm_v2": true}`)) + require.Error(t, err) + require.Contains(t, err.Error(), "at the same time") }) t.Run("test get connection - error", func(t *testing.T) { @@ -234,7 +316,7 @@ func TestCommand_Connections(t *testing.T) { require.NotNil(t, cmd) var b bytes.Buffer - err = cmd.Connections(&b, nil) + err = cmd.Connections(&b, bytes.NewBufferString("{}")) require.Error(t, err) require.Contains(t, err.Error(), "get router connections") }) @@ -533,3 +615,11 @@ func newMockProvider(serviceMap map[string]interface{}) *mockprovider.Provider { ProtocolStateStorageProviderValue: mockstorage.NewMockStoreProvider(), } } + +type errReader struct { + err error +} + +func (e *errReader) Read([]byte) (int, error) { + return 0, e.err +} diff --git a/pkg/controller/command/mediator/models.go b/pkg/controller/command/mediator/models.go index 43256bc90..40148a916 100644 --- a/pkg/controller/command/mediator/models.go +++ b/pkg/controller/command/mediator/models.go @@ -16,6 +16,12 @@ type RegisterRoute struct { ConnectionID string `json:"connectionID"` } +// ConnectionsRequest contains parameters for filtering when requesting router connections. +type ConnectionsRequest struct { + DIDCommV1Only bool `json:"didcomm_v1"` + DIDCommV2Only bool `json:"didcomm_v2"` +} + // ConnectionsResponse is response for router`s connections. type ConnectionsResponse struct { Connections []string `json:"connections"` diff --git a/pkg/didcomm/protocol/mediator/api.go b/pkg/didcomm/protocol/mediator/api.go index f318e6c5e..0f2ad9663 100644 --- a/pkg/didcomm/protocol/mediator/api.go +++ b/pkg/didcomm/protocol/mediator/api.go @@ -15,5 +15,5 @@ type ProtocolService interface { Config(connID string) (*Config, error) // GetConnections returns all router connections - GetConnections() ([]string, error) + GetConnections(options ...ConnectionOption) ([]string, error) } diff --git a/pkg/didcomm/protocol/mediator/service.go b/pkg/didcomm/protocol/mediator/service.go index 4e1a8d235..d737d59c8 100644 --- a/pkg/didcomm/protocol/mediator/service.go +++ b/pkg/didcomm/protocol/mediator/service.go @@ -126,6 +126,11 @@ type callback struct { err error } +type routerConnectionEntry struct { + ConnectionID string `json:"connectionID"` + DIDCommVersion service.Version `json:"didcomm_version,omitempty"` +} + type connections interface { GetConnectionIDByDIDs(string, string) (string, error) GetConnectionRecord(string) (*connection.Record, error) @@ -646,7 +651,7 @@ func (s *Service) doRegistration(record *connection.Record, req *Request, timeou logger.Debugf("saved router config from inbound grant: %+v", grant) // save the connectionID of the router - return s.saveRouterConnectionID(record.ConnectionID) + return s.saveRouterConnectionID(record.ConnectionID, record.DIDCommVersion) } func (s *Service) getGrant(id string, timeout time.Duration) (*Grant, error) { @@ -700,7 +705,13 @@ func (s *Service) Unregister(connID string) error { } // GetConnections returns the connections of the router. -func (s *Service) GetConnections() ([]string, error) { +func (s *Service) GetConnections(options ...ConnectionOption) ([]string, error) { + opts := &getConnectionOpts{} + + for _, option := range options { + option(opts) + } + records, err := s.routeStore.Query(routeConnIDDataKey) if err != nil { return nil, fmt.Errorf("failed to query route store: %w", err) @@ -721,7 +732,16 @@ func (s *Service) GetConnections() ([]string, error) { return nil, fmt.Errorf("failed to get value from records: %w", err) } - conns = append(conns, string(value)) + data := &routerConnectionEntry{} + + err = json.Unmarshal(value, data) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal router connection entry: %w", err) + } + + if opts.version == "" || opts.version == data.DIDCommVersion { + conns = append(conns, data.ConnectionID) + } more, err = records.Next() if err != nil { @@ -838,8 +858,18 @@ func (s *Service) deleteRouterConnectionID(connID string) error { return s.routeStore.Delete(fmt.Sprintf(routeConnIDDataKey, connID)) } -func (s *Service) saveRouterConnectionID(connID string) error { - return s.routeStore.Put(fmt.Sprintf(routeConnIDDataKey, connID), []byte(connID), storage.Tag{Name: routeConnIDDataKey}) +func (s *Service) saveRouterConnectionID(connID string, didcommVersion service.Version) error { + data := &routerConnectionEntry{ + ConnectionID: connID, + DIDCommVersion: didcommVersion, + } + + dataBytes, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("marshalling router connection ID data: %w", err) + } + + return s.routeStore.Put(fmt.Sprintf(routeConnIDDataKey, connID), dataBytes, storage.Tag{Name: routeConnIDDataKey}) } type config struct { @@ -918,3 +948,17 @@ func parseClientOpts(options ...ClientOption) *ClientOptions { return opts } + +type getConnectionOpts struct { + version service.Version +} + +// ConnectionOption option for Service.GetConnections. +type ConnectionOption func(opts *getConnectionOpts) + +// ConnectionByVersion filter for mediator connections of the given DIDComm version. +func ConnectionByVersion(v service.Version) ConnectionOption { + return func(opts *getConnectionOpts) { + opts.version = v + } +} diff --git a/pkg/didcomm/protocol/mediator/service_test.go b/pkg/didcomm/protocol/mediator/service_test.go index 9676e0471..753fabe5f 100644 --- a/pkg/didcomm/protocol/mediator/service_test.go +++ b/pkg/didcomm/protocol/mediator/service_test.go @@ -31,6 +31,7 @@ import ( mockstore "github.com/hyperledger/aries-framework-go/pkg/mock/storage" mockvdr "github.com/hyperledger/aries-framework-go/pkg/mock/vdr" "github.com/hyperledger/aries-framework-go/pkg/store/connection" + "github.com/hyperledger/aries-framework-go/spi/storage" ) const ( @@ -1096,7 +1097,7 @@ func TestUnregister(t *testing.T) { ) require.NoError(t, err) - s[fmt.Sprintf(routeConnIDDataKey, connID)] = mockstore.DBEntry{Value: []byte("conn-abc-xyz")} + s[fmt.Sprintf(routeConnIDDataKey, connID)] = mockstore.DBEntry{Value: []byte("{\"connectionID\":\"conn-abc-xyz\"}")} err = svc.Unregister(connID) require.NoError(t, err) @@ -1179,7 +1180,7 @@ func TestKeylistUpdate(t *testing.T) { require.NoError(t, err) // save router connID - require.NoError(t, svc.saveRouterConnectionID("conn")) + require.NoError(t, svc.saveRouterConnectionID("conn", "")) // save connections connRec := &connection.Record{ @@ -1245,7 +1246,7 @@ func TestKeylistUpdate(t *testing.T) { require.Contains(t, err.Error(), "router not registered") // save router connID - require.NoError(t, svc.saveRouterConnectionID("conn")) + require.NoError(t, svc.saveRouterConnectionID("conn", "")) // no connections saved err = svc.AddKey("conn", recKey) @@ -1298,7 +1299,7 @@ func TestKeylistUpdate(t *testing.T) { connBytes, err := json.Marshal(connRec) require.NoError(t, err) s["conn_conn2"] = mockstore.DBEntry{Value: connBytes} - require.NoError(t, svc.saveRouterConnectionID("conn2")) + require.NoError(t, svc.saveRouterConnectionID("conn2", "")) err = svc.AddKey("conn2", "recKey") require.Error(t, err) @@ -1342,7 +1343,7 @@ func TestConfig(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, svc.saveRouterConnectionID("connID-123")) + require.NoError(t, svc.saveRouterConnectionID("connID-123", "")) require.NoError(t, svc.saveRouterConfig("connID-123", &config{ RouterEndpoint: ENDPOINT, RoutingKeys: routingKeys, @@ -1386,7 +1387,7 @@ func TestConfig(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, svc.saveRouterConnectionID("connID-123")) + require.NoError(t, svc.saveRouterConnectionID("connID-123", "")) conf, err := svc.Config("connID-123") require.Error(t, err) @@ -1409,7 +1410,7 @@ func TestConfig(t *testing.T) { const conn = "connID-123" - require.NoError(t, svc.saveRouterConnectionID(conn)) + require.NoError(t, svc.saveRouterConnectionID(conn, "")) require.NoError(t, svc.routeStore.Put(fmt.Sprintf(routeConfigDataKey, conn), []byte("invalid data"))) conf, err := svc.Config(conn) @@ -1433,7 +1434,7 @@ func TestConfig(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, svc.saveRouterConnectionID("connID-123")) + require.NoError(t, svc.saveRouterConnectionID("connID-123", "")) require.NoError(t, svc.routeStore.Put(routeConfigDataKey, []byte("invalid data"))) conf, err := svc.Config("connID-123") @@ -1458,7 +1459,7 @@ func TestGetConnections(t *testing.T) { ) require.NoError(t, err) - err = svc.saveRouterConnectionID(routerConnectionID) + err = svc.saveRouterConnectionID(routerConnectionID, "") require.NoError(t, err) connID, err := svc.GetConnections() @@ -1482,6 +1483,68 @@ func TestGetConnections(t *testing.T) { require.NoError(t, err) require.Empty(t, connID) }) + + t.Run("test get connection - filter by didcomm version", func(t *testing.T) { + svc, err := New( + &mockprovider.Provider{ + ServiceMap: map[string]interface{}{ + messagepickup.MessagePickup: &mockmessagep.MockMessagePickupSvc{}, + }, + StorageProviderValue: mem.NewProvider(), + ProtocolStateStorageProviderValue: mem.NewProvider(), + }, + ) + require.NoError(t, err) + + const ( + connID1 = "conn-id-1" + connID2 = "conn-id-2" + ) + + err = svc.saveRouterConnectionID(connID1, service.V1) + require.NoError(t, err) + + err = svc.saveRouterConnectionID(connID2, service.V2) + require.NoError(t, err) + + connIDs, err := svc.GetConnections() + require.NoError(t, err) + require.Len(t, connIDs, 2) + + connIDs, err = svc.GetConnections(ConnectionByVersion(service.V1)) + require.NoError(t, err) + require.Len(t, connIDs, 1) + require.Equal(t, connID1, connIDs[0]) + + connIDs, err = svc.GetConnections(ConnectionByVersion(service.V2)) + require.NoError(t, err) + require.Len(t, connIDs, 1) + require.Equal(t, connID2, connIDs[0]) + }) + + t.Run("test get connection - fail to parse connection entry", func(t *testing.T) { + svc, err := New( + &mockprovider.Provider{ + ServiceMap: map[string]interface{}{ + messagepickup.MessagePickup: &mockmessagep.MockMessagePickupSvc{}, + }, + StorageProviderValue: mem.NewProvider(), + ProtocolStateStorageProviderValue: mem.NewProvider(), + }, + ) + require.NoError(t, err) + + err = svc.routeStore.Put( + fmt.Sprintf(routeConnIDDataKey, routerConnectionID), + []byte("foo"), + storage.Tag{Name: routeConnIDDataKey}, + ) + require.NoError(t, err) + + _, err = svc.GetConnections() + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal router connection entry") + }) } func generateRequestMsgPayload(t *testing.T, id string) service.DIDCommMsg { diff --git a/pkg/didcomm/protocol/mediator/util_test.go b/pkg/didcomm/protocol/mediator/util_test.go index 4522abe68..928875956 100644 --- a/pkg/didcomm/protocol/mediator/util_test.go +++ b/pkg/didcomm/protocol/mediator/util_test.go @@ -87,7 +87,7 @@ func (m *mockRouteSvc) AddKey(connID, recKey string) error { } // AddKey adds agents recKey to the router. -func (m *mockRouteSvc) GetConnections() ([]string, error) { +func (m *mockRouteSvc) GetConnections(...ConnectionOption) ([]string, error) { return m.Connections, m.ConnectionsErr } diff --git a/pkg/mock/didcomm/protocol/mediator/mock_mediator.go b/pkg/mock/didcomm/protocol/mediator/mock_mediator.go index 5a290d849..d133ca404 100644 --- a/pkg/mock/didcomm/protocol/mediator/mock_mediator.go +++ b/pkg/mock/didcomm/protocol/mediator/mock_mediator.go @@ -115,7 +115,7 @@ func (m *MockMediatorSvc) Config(connID string) (*mediator.Config, error) { } // GetConnections returns router`s connections. -func (m *MockMediatorSvc) GetConnections() ([]string, error) { +func (m *MockMediatorSvc) GetConnections(...mediator.ConnectionOption) ([]string, error) { if m.GetConnectionsErr != nil { return nil, m.GetConnectionsErr }