diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 1ea8c40ea1..79f4b3f214 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -5,11 +5,14 @@ import ( "crypto/ed25519" "encoding/json" "fmt" + "net/http" + "net/http/httptest" "strings" "sync" "testing" "time" + "github.com/matrix-org/dendrite/federationapi/routing" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -17,7 +20,10 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" @@ -362,3 +368,126 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { } } } + +func TestNotaryServer(t *testing.T) { + testCases := []struct { + name string + httpBody string + pubKeyRequest *gomatrixserverlib.PublicKeyNotaryLookupRequest + validateFunc func(t *testing.T, response util.JSONResponse) + }{ + { + name: "empty httpBody", + validateFunc: func(t *testing.T, resp util.JSONResponse) { + assert.Equal(t, http.StatusBadRequest, resp.Code) + nk, ok := resp.JSON.(spec.MatrixError) + assert.True(t, ok) + assert.Equal(t, spec.ErrorBadJSON, nk.ErrCode) + }, + }, + { + name: "valid but empty httpBody", + httpBody: "{}", + validateFunc: func(t *testing.T, resp util.JSONResponse) { + want := util.JSONResponse{ + Code: http.StatusOK, + JSON: routing.NotaryKeysResponse{ServerKeys: []json.RawMessage{}}, + } + assert.Equal(t, want, resp) + }, + }, + { + name: "request all keys using an empty criteria", + httpBody: `{"server_keys":{"servera":{}}}`, + validateFunc: func(t *testing.T, resp util.JSONResponse) { + assert.Equal(t, http.StatusOK, resp.Code) + nk, ok := resp.JSON.(routing.NotaryKeysResponse) + assert.True(t, ok) + assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str) + assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists()) + }, + }, + { + name: "request all keys using null as the criteria", + httpBody: `{"server_keys":{"servera":null}}`, + validateFunc: func(t *testing.T, resp util.JSONResponse) { + assert.Equal(t, http.StatusOK, resp.Code) + nk, ok := resp.JSON.(routing.NotaryKeysResponse) + assert.True(t, ok) + assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str) + assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists()) + }, + }, + { + name: "request specific key", + httpBody: `{"server_keys":{"servera":{"ed25519:someID":{}}}}`, + validateFunc: func(t *testing.T, resp util.JSONResponse) { + assert.Equal(t, http.StatusOK, resp.Code) + nk, ok := resp.JSON.(routing.NotaryKeysResponse) + assert.True(t, ok) + assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str) + assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists()) + }, + }, + { + name: "request multiple servers", + httpBody: `{"server_keys":{"servera":{"ed25519:someID":{}},"serverb":{"ed25519:someID":{}}}}`, + validateFunc: func(t *testing.T, resp util.JSONResponse) { + assert.Equal(t, http.StatusOK, resp.Code) + nk, ok := resp.JSON.(routing.NotaryKeysResponse) + assert.True(t, ok) + wantServers := map[string]struct{}{ + "servera": {}, + "serverb": {}, + } + for _, js := range nk.ServerKeys { + serverName := gjson.GetBytes(js, "server_name").Str + _, ok = wantServers[serverName] + assert.True(t, ok, "unexpected servername: %s", serverName) + delete(wantServers, serverName) + assert.True(t, gjson.GetBytes(js, "verify_keys.ed25519:someID").Exists()) + } + if len(wantServers) > 0 { + t.Fatalf("expected response to also contain: %#v", wantServers) + } + }, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + fc := &fedClient{ + keys: map[spec.ServerName]struct { + key ed25519.PrivateKey + keyID gomatrixserverlib.KeyID + }{ + "servera": { + key: test.PrivateKeyA, + keyID: "ed25519:someID", + }, + "serverb": { + key: test.PrivateKeyB, + keyID: "ed25519:someID", + }, + }, + } + + fedAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, fc, nil, caches, nil, true) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tc.httpBody)) + req.Host = string(cfg.Global.ServerName) + + resp := routing.NotaryKeys(req, &cfg.FederationAPI, fedAPI, tc.pubKeyRequest) + // assert that we received the expected response + tc.validateFunc(t, resp) + }) + } + + }) +} diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index e53f19ff8c..21e77c48d9 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -43,6 +43,15 @@ func (a *FederationInternalAPI) fetchServerKeysFromCache( ctx context.Context, req *api.QueryServerKeysRequest, ) ([]gomatrixserverlib.ServerKeys, error) { var results []gomatrixserverlib.ServerKeys + + // We got a request for _all_ server keys, return them. + if len(req.KeyIDToCriteria) == 0 { + serverKeysResponses, _ := a.db.GetNotaryKeys(ctx, req.ServerName, []gomatrixserverlib.KeyID{}) + if len(serverKeysResponses) == 0 { + return nil, fmt.Errorf("failed to find server key response for server %s", req.ServerName) + } + return serverKeysResponses, nil + } for keyID, criteria := range req.KeyIDToCriteria { serverKeysResponses, _ := a.db.GetNotaryKeys(ctx, req.ServerName, []gomatrixserverlib.KeyID{keyID}) if len(serverKeysResponses) == 0 { diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 3d8ff2deab..38a88e4b16 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -197,6 +197,10 @@ func localKeys(cfg *config.FederationAPI, serverName spec.ServerName) (*gomatrix return &keys, err } +type NotaryKeysResponse struct { + ServerKeys []json.RawMessage `json:"server_keys"` +} + func NotaryKeys( httpReq *http.Request, cfg *config.FederationAPI, fsAPI federationAPI.FederationInternalAPI, @@ -217,10 +221,9 @@ func NotaryKeys( } } - var response struct { - ServerKeys []json.RawMessage `json:"server_keys"` + response := NotaryKeysResponse{ + ServerKeys: []json.RawMessage{}, } - response.ServerKeys = []json.RawMessage{} for serverName, kidToCriteria := range req.ServerKeys { var keyList []gomatrixserverlib.ServerKeys