From ba542fedcd364a812e2bac9f90bd2f87656efea2 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sat, 21 Dec 2024 00:59:14 +0000 Subject: [PATCH 01/71] mas: added /auth_issuer endpoint --- clientapi/routing/routing.go | 9 +++++++++ setup/config/config_mscs.go | 26 ++++++++++++++++++++++++++ setup/mscs/mscs.go | 1 + 3 files changed, 36 insertions(+) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d72638ee0..373fa03fd 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -328,6 +328,15 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) + if m := mscCfg.MSC2965; mscCfg.Enabled("msc2965") && m != nil && m.Enabled { + unstableMux.Handle("/org.matrix.msc2965/auth_issuer", + httputil.MakeExternalAPI("auth_issuer", func(r *http.Request) util.JSONResponse { + return util.JSONResponse{Code: http.StatusOK, JSON: map[string]string{ + "issuer": m.Issuer, + }} + })) + } + if mscCfg.Enabled("msc2753") { v3mux.Handle("/peek/{roomIDOrAlias}", httputil.MakeAuthAPI(spec.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index ce491cd72..44fa5e72a 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -7,8 +7,11 @@ type MSCs struct { // 'msc2444': Peeking over federation - https://github.com/matrix-org/matrix-doc/pull/2444 // 'msc2753': Peeking via /sync - https://github.com/matrix-org/matrix-doc/pull/2753 // 'msc2836': Threading - https://github.com/matrix-org/matrix-doc/pull/2836 + // 'msc2965': Delegate auth to an OIDC provider https://github.com/matrix-org/matrix-spec-proposals/pull/2965 MSCs []string `yaml:"mscs"` + MSC2965 *MSC2965 `yaml:"msc2965,omitempty"` + Database DatabaseOptions `yaml:"database,omitempty"` } @@ -34,4 +37,27 @@ func (c *MSCs) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString)) } + if m := c.MSC2965; m != nil { + m.Verify(configErrs) + } +} + +type MSC2965 struct { + Enabled bool `yaml:"enabled"` + Issuer string `yaml:"issuer"` + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + AdminToken string `yaml:"admin_token"` + AccountManagementURL string `yaml:"account_management_url"` +} + +func (m *MSC2965) Verify(configErrs *ConfigErrors) { + if !m.Enabled { + return + } + checkNotEmpty(configErrs, "mscs.msc2965.issuer", string(m.Issuer)) + checkNotEmpty(configErrs, "mscs.msc2965.client_id", string(m.ClientID)) + checkNotEmpty(configErrs, "mscs.msc2965.client_secret", string(m.ClientSecret)) + checkNotEmpty(configErrs, "mscs.msc2965.admin_token", string(m.AdminToken)) + checkNotEmpty(configErrs, "mscs.msc2965.account_management_url", string(m.AccountManagementURL)) } diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go index fc360b5d8..91bc1a827 100644 --- a/setup/mscs/mscs.go +++ b/setup/mscs/mscs.go @@ -37,6 +37,7 @@ func EnableMSC(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.R return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) case "msc2444": // enabled inside federationapi case "msc2753": // enabled inside clientapi + case "msc2965": // enabled inside clientapi default: logrus.Warnf("EnableMSC: unknown MSC '%s', this MSC is either not supported or is natively supported by Dendrite", msc) } From 2c47959600fd3df21f122023c7757e3048af8999 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sun, 22 Dec 2024 00:23:25 +0000 Subject: [PATCH 02/71] mas: added username_available endpoint --- clientapi/routing/admin.go | 21 +++++++++++++++++++++ clientapi/routing/routing.go | 7 ++++++- internal/httputil/httpapi.go | 32 ++++++++++++++++++++++++++++++++ setup/mscs/mscs.go | 2 +- 4 files changed, 60 insertions(+), 2 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 48e58209c..408811661 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -496,6 +496,27 @@ func AdminDownloadState(req *http.Request, device *api.Device, rsAPI roomserverA } } +func AdminCheckUsernameAvailable( + req *http.Request, + userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, +) util.JSONResponse { + username := req.URL.Query().Get("username") + if username == "" { + return util.MessageResponse(http.StatusBadRequest, "Query parameter 'username' is missing or empty") + } + rq := userapi.QueryAccountAvailabilityRequest{Localpart: username, ServerName: cfg.Matrix.ServerName} + rs := userapi.QueryAccountAvailabilityResponse{} + if err := userAPI.QueryAccountAvailability(req.Context(), &rq, &rs); err != nil { + return util.ErrorResponse(err) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: map[string]bool{"available": rs.Available}, + } +} + // GetEventReports returns reported events for a given user/room. func GetEventReports( req *http.Request, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 373fa03fd..0544c3c22 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -334,7 +334,12 @@ func Setup( return util.JSONResponse{Code: http.StatusOK, JSON: map[string]string{ "issuer": m.Issuer, }} - })) + })).Methods(http.MethodGet) + + synapseAdminRouter.Handle("/admin/v1/username_available", + httputil.MakeServiceAdminAPI("admin_username_available", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminCheckUsernameAvailable(r, userAPI, cfg) + })).Methods(http.MethodGet) } if mscCfg.Enabled("msc2753") { diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index d32557679..5a332d6fa 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -136,6 +136,38 @@ func MakeAdminAPI( }) } +// MakeServiceAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be +// completed by a trusted service e.g. Matrix Auth Service. +func MakeServiceAdminAPI( + metricsName, serviceToken string, + f func(*http.Request) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + logger := util.GetLogger(req.Context()) + token, err := auth.ExtractAccessToken(req) + + if err != nil { + logger.Debugf("ExtractAccessToken %s -> HTTP %d", req.RemoteAddr, http.StatusUnauthorized) + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.MissingToken(err.Error()), + } + } + if token != serviceToken { + logger.Debugf("Invalid service token '%s'", token) + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.UnknownToken(token), + } + } + // add the service addr to the logger + logger = logger.WithField("service_useragent", req.UserAgent()) + req = req.WithContext(util.ContextWithLogger(req.Context(), logger)) + return f(req) + } + return MakeExternalAPI(metricsName, h) +} + // MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler. // This is used for APIs that are called from the internet. func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler { diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go index 91bc1a827..6a220e62f 100644 --- a/setup/mscs/mscs.go +++ b/setup/mscs/mscs.go @@ -37,7 +37,7 @@ func EnableMSC(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.R return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) case "msc2444": // enabled inside federationapi case "msc2753": // enabled inside clientapi - case "msc2965": // enabled inside clientapi + case "msc2965": // enabled inside clientapi default: logrus.Warnf("EnableMSC: unknown MSC '%s', this MSC is either not supported or is natively supported by Dendrite", msc) } From e1dfe62b2072c0b9787febc0222b1fdd2c3311b5 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sun, 22 Dec 2024 00:35:04 +0000 Subject: [PATCH 03/71] mas: rename msc2965 to msc3861 --- clientapi/routing/routing.go | 2 +- setup/config/config_mscs.go | 20 ++++++++++---------- setup/mscs/mscs.go | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 0544c3c22..73cfcfc3b 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -328,7 +328,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) - if m := mscCfg.MSC2965; mscCfg.Enabled("msc2965") && m != nil && m.Enabled { + if m := mscCfg.MSC3861; mscCfg.Enabled("msc3861") && m != nil && m.Enabled { unstableMux.Handle("/org.matrix.msc2965/auth_issuer", httputil.MakeExternalAPI("auth_issuer", func(r *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]string{ diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index 44fa5e72a..d6a51b651 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -7,10 +7,10 @@ type MSCs struct { // 'msc2444': Peeking over federation - https://github.com/matrix-org/matrix-doc/pull/2444 // 'msc2753': Peeking via /sync - https://github.com/matrix-org/matrix-doc/pull/2753 // 'msc2836': Threading - https://github.com/matrix-org/matrix-doc/pull/2836 - // 'msc2965': Delegate auth to an OIDC provider https://github.com/matrix-org/matrix-spec-proposals/pull/2965 + // 'msc3861': Delegate auth to an OIDC provider https://github.com/matrix-org/matrix-spec-proposals/pull/3861 MSCs []string `yaml:"mscs"` - MSC2965 *MSC2965 `yaml:"msc2965,omitempty"` + MSC3861 *MSC3861 `yaml:"msc3861,omitempty"` Database DatabaseOptions `yaml:"database,omitempty"` } @@ -37,12 +37,12 @@ func (c *MSCs) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString)) } - if m := c.MSC2965; m != nil { + if m := c.MSC3861; m != nil { m.Verify(configErrs) } } -type MSC2965 struct { +type MSC3861 struct { Enabled bool `yaml:"enabled"` Issuer string `yaml:"issuer"` ClientID string `yaml:"client_id"` @@ -51,13 +51,13 @@ type MSC2965 struct { AccountManagementURL string `yaml:"account_management_url"` } -func (m *MSC2965) Verify(configErrs *ConfigErrors) { +func (m *MSC3861) Verify(configErrs *ConfigErrors) { if !m.Enabled { return } - checkNotEmpty(configErrs, "mscs.msc2965.issuer", string(m.Issuer)) - checkNotEmpty(configErrs, "mscs.msc2965.client_id", string(m.ClientID)) - checkNotEmpty(configErrs, "mscs.msc2965.client_secret", string(m.ClientSecret)) - checkNotEmpty(configErrs, "mscs.msc2965.admin_token", string(m.AdminToken)) - checkNotEmpty(configErrs, "mscs.msc2965.account_management_url", string(m.AccountManagementURL)) + checkNotEmpty(configErrs, "mscs.msc3861.issuer", string(m.Issuer)) + checkNotEmpty(configErrs, "mscs.msc3861.client_id", string(m.ClientID)) + checkNotEmpty(configErrs, "mscs.msc3861.client_secret", string(m.ClientSecret)) + checkNotEmpty(configErrs, "mscs.msc3861.admin_token", string(m.AdminToken)) + checkNotEmpty(configErrs, "mscs.msc3861.account_management_url", string(m.AccountManagementURL)) } diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go index 6a220e62f..8df539ba7 100644 --- a/setup/mscs/mscs.go +++ b/setup/mscs/mscs.go @@ -37,7 +37,7 @@ func EnableMSC(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.R return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) case "msc2444": // enabled inside federationapi case "msc2753": // enabled inside clientapi - case "msc2965": // enabled inside clientapi + case "msc3861": // enabled inside clientapi default: logrus.Warnf("EnableMSC: unknown MSC '%s', this MSC is either not supported or is natively supported by Dendrite", msc) } From 150be588f5b6719c0638cdae9ade5c3b2c967609 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Tue, 24 Dec 2024 03:06:26 +0000 Subject: [PATCH 04/71] mas: added localpart_external_ids table --- userapi/api/api.go | 8 ++ userapi/storage/interface.go | 7 ++ .../postgres/localpart_external_ids_table.go | 97 +++++++++++++++++++ userapi/storage/postgres/storage.go | 5 + userapi/storage/shared/storage.go | 13 +++ .../sqlite3/localpart_external_ids_table.go | 97 +++++++++++++++++++ userapi/storage/sqlite3/storage.go | 5 + userapi/storage/tables/interface.go | 10 ++ 8 files changed, 242 insertions(+) create mode 100644 userapi/storage/postgres/localpart_external_ids_table.go create mode 100644 userapi/storage/sqlite3/localpart_external_ids_table.go diff --git a/userapi/api/api.go b/userapi/api/api.go index 264821296..f0ef26bf1 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -471,6 +471,14 @@ type OpenIDTokenAttributes struct { ExpiresAtMS int64 } +// LocalpartExternalID represents a connection between Matrix account and OpenID Connect provider +type LocalpartExternalID struct { + Localpart string + ExternalID string + AuthProvider string + CreatedTs int64 +} + // UserInfo is for returning information about the user an OpenID token was issued for type UserInfo struct { Sub string // The Matrix user's ID who generated the token diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 2a46a7fd7..13d8c2013 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -134,6 +134,12 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } +type LocalpartExternalID interface { + CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error + GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) + DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error +} + type UserDatabase interface { Account AccountData @@ -147,6 +153,7 @@ type UserDatabase interface { Statistics ThreePID RegistrationTokens + LocalpartExternalID } type KeyChangeDatabase interface { diff --git a/userapi/storage/postgres/localpart_external_ids_table.go b/userapi/storage/postgres/localpart_external_ids_table.go new file mode 100644 index 000000000..9bc47dbf6 --- /dev/null +++ b/userapi/storage/postgres/localpart_external_ids_table.go @@ -0,0 +1,97 @@ +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/tables" + log "github.com/sirupsen/logrus" +) + +const localpartExternalIDsSchema = ` +-- Stores data about connections between accounts and third-party auth providers +CREATE TABLE IF NOT EXISTS userapi_localpart_external_ids ( + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- The external ID + external_id TEXT NOT NULL, + -- Auth provider ID (see OIDCProvider.IDPID) + auth_provider TEXT NOT NULL, + -- When this connection was created, as a unix timestamp. + created_ts BIGINT NOT NULL, + + CONSTRAINT userapi_localpart_external_ids_external_id_auth_provider_unique UNIQUE(external_id, auth_provider), + CONSTRAINT userapi_localpart_external_ids_localpart_external_id_auth_provider_unique UNIQUE(localpart, external_id, auth_provider) +); + +-- This index allows efficient lookup of the local user by the external ID +CREATE INDEX IF NOT EXISTS userapi_external_id_auth_provider_idx ON userapi_localpart_external_ids(external_id, auth_provider); +` + +const insertUserExternalIDSQL = "" + + "INSERT INTO userapi_localpart_external_ids(localpart, external_id, auth_provider, created_ts) VALUES ($1, $2, $3, $4)" + +const selectUserExternalIDSQL = "" + + "SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +const deleteUserExternalIDSQL = "" + + "SELECT localpart, external_id, auth_provider, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +type localpartExternalIDStatements struct { + db *sql.DB + insertUserExternalIDStmt *sql.Stmt + selectUserExternalIDStmt *sql.Stmt + deleteUserExternalIDStmt *sql.Stmt +} + +func NewPostgresLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDsTable, error) { + s := &localpartExternalIDStatements{ + db: db, + } + _, err := db.Exec(localpartExternalIDsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertUserExternalIDStmt, insertUserExternalIDSQL}, + {&s.selectUserExternalIDStmt, selectUserExternalIDSQL}, + {&s.deleteUserExternalIDStmt, deleteUserExternalIDSQL}, + }.Prepare(db) +} + +// Select selects an existing OpenID Connect connection from the database +func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) { + ret := api.LocalpartExternalID{ + ExternalID: externalID, + AuthProvider: authProvider, + } + err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan( + &ret.Localpart, &ret.CreatedTs, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + log.WithError(err).Error("Unable to retrieve localpart from the db") + return nil, err + } + + return &ret, nil +} + +// Insert creates a new record representing an OpenID Connect connection between Matrix and external accounts. +func (u *localpartExternalIDStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix()) + return err +} + +// Delete deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account. +func (u *localpartExternalIDStatements) Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, externalID, authProvider) + return err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index c7fb9d29b..eff12a64a 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -97,6 +97,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties if err != nil { return nil, fmt.Errorf("NewPostgresStatsTable: %w", err) } + localpartExternalIDsTable, err := NewPostgresLocalpartExternalIDsTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteLocalpartExternalIDsTable: %w", err) + } m = sqlutil.NewMigrator(db) m.AddMigrations(sqlutil.Migration{ @@ -123,6 +127,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties Notifications: notificationsTable, RegistrationTokens: registationTokensTable, Stats: statsTable, + LocalpartExternalIDs: localpartExternalIDsTable, ServerName: serverName, DB: db, Writer: writer, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 44ace733e..aade4be1f 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -49,6 +49,7 @@ type Database struct { Notifications tables.NotificationTable Pushers tables.PusherTable Stats tables.StatsTable + LocalpartExternalIDs tables.LocalpartExternalIDsTable LoginTokenLifetime time.Duration ServerName spec.ServerName BcryptCost int @@ -870,6 +871,18 @@ func (d *Database) UpsertPusher( }) } +func (d *Database) CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error { + return d.LocalpartExternalIDs.Insert(ctx, nil, localpart, externalID, authProvider) +} + +func (d *Database) GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) { + return d.LocalpartExternalIDs.Select(ctx, nil, externalID, authProvider) +} + +func (d *Database) DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error { + return d.LocalpartExternalIDs.Delete(ctx, nil, externalID, authProvider) +} + // GetPushers returns the pushers matching the given localpart. func (d *Database) GetPushers( ctx context.Context, localpart string, serverName spec.ServerName, diff --git a/userapi/storage/sqlite3/localpart_external_ids_table.go b/userapi/storage/sqlite3/localpart_external_ids_table.go new file mode 100644 index 000000000..30f1fc60e --- /dev/null +++ b/userapi/storage/sqlite3/localpart_external_ids_table.go @@ -0,0 +1,97 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/tables" + log "github.com/sirupsen/logrus" +) + +const localpartExternalIDsSchema = ` +-- Stores data about connections between accounts and third-party auth providers +CREATE TABLE IF NOT EXISTS userapi_localpart_external_ids ( + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- The external ID + external_id TEXT NOT NULL, + -- Auth provider ID (see OIDCProvider.IDPID) + auth_provider TEXT NOT NULL, + -- When this connection was created, as a unix timestamp. + created_ts BIGINT NOT NULL, + + UNIQUE(external_id, auth_provider), + UNIQUE(localpart, external_id, auth_provider) +); + +-- This index allows efficient lookup of the local user by the external ID +CREATE INDEX IF NOT EXISTS userapi_external_id_auth_provider_idx ON userapi_localpart_external_ids(external_id, auth_provider); +` + +const insertLocalpartExternalIDSQL = "" + + "INSERT INTO userapi_localpart_external_ids(localpart, external_id, auth_provider, created_ts) VALUES ($1, $2, $3, $4)" + +const selectLocalpartExternalIDSQL = "" + + "SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +const deleteLocalpartExternalIDSQL = "" + + "SELECT localpart, external_id, auth_provider, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + +type localpartExternalIDStatements struct { + db *sql.DB + insertUserExternalIDStmt *sql.Stmt + selectUserExternalIDStmt *sql.Stmt + deleteUserExternalIDStmt *sql.Stmt +} + +func NewSQLiteLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDsTable, error) { + s := &localpartExternalIDStatements{ + db: db, + } + _, err := db.Exec(localpartExternalIDsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertUserExternalIDStmt, insertLocalpartExternalIDSQL}, + {&s.selectUserExternalIDStmt, selectLocalpartExternalIDSQL}, + {&s.deleteUserExternalIDStmt, deleteLocalpartExternalIDSQL}, + }.Prepare(db) +} + +// Select selects an existing OpenID Connect connection from the database +func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) { + ret := api.LocalpartExternalID{ + ExternalID: externalID, + AuthProvider: authProvider, + } + err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan( + &ret.Localpart, &ret.CreatedTs, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + log.WithError(err).Error("Unable to retrieve localpart from the db") + return nil, err + } + + return &ret, nil +} + +// Insert creates a new record representing an OpenID Connect connection between Matrix and external accounts. +func (u *localpartExternalIDStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix()) + return err +} + +// Delete deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account. +func (u *localpartExternalIDStatements) Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error { + stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt) + _, err := stmt.ExecContext(ctx, externalID, authProvider) + return err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 6d906191f..80ecaf83c 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -94,6 +94,10 @@ func NewUserDatabase(ctx context.Context, conMan *sqlutil.Connections, dbPropert if err != nil { return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) } + localpartExternalIDsTable, err := NewSQLiteLocalpartExternalIDsTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteUserExternalIDsTable: %w", err) + } m = sqlutil.NewMigrator(db) m.AddMigrations(sqlutil.Migration{ @@ -119,6 +123,7 @@ func NewUserDatabase(ctx context.Context, conMan *sqlutil.Connections, dbPropert Pushers: pusherTable, Notifications: notificationsTable, Stats: statsTable, + LocalpartExternalIDs: localpartExternalIDsTable, ServerName: serverName, DB: db, Writer: writer, diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 44f31a5c5..7b141629a 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -127,6 +127,16 @@ type StatsTable interface { UpsertDailyStats(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error } +type LocalpartExternalIDsTable interface { + Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) + Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error + Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error +} + +type UIAuthSessionsTable interface { + SelectByID(ctx context.Context, txn *sql.Tx, sessionID int) (*api.UIAuthSession, error) +} + type NotificationFilter uint32 const ( From 63a199cec35b307c29229083208961949d9deb3e Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sun, 29 Dec 2024 23:53:37 +0000 Subject: [PATCH 05/71] mas: first successful attempt of login with via mas --- clientapi/admin_test.go | 42 +- clientapi/auth/auth.go | 47 -- clientapi/auth/authtypes/logintypes.go | 1 + clientapi/auth/default_user_verifier.go | 59 +++ clientapi/clientapi.go | 7 +- clientapi/routing/admin.go | 161 +++++++ clientapi/routing/key_crosssigning.go | 113 ++++- clientapi/routing/password.go | 1 + clientapi/routing/register.go | 6 +- clientapi/routing/routing.go | 260 +++++----- internal/httputil/httpapi.go | 26 +- mediaapi/mediaapi.go | 5 +- mediaapi/routing/routing.go | 10 +- setup/config/config_mscs.go | 9 +- setup/monolith.go | 30 +- setup/mscs/msc2836/msc2836.go | 4 +- setup/mscs/msc3861/msc3861.go | 17 + setup/mscs/msc3861/msc3861_user_verifier.go | 444 ++++++++++++++++++ setup/mscs/mscs.go | 6 +- syncapi/routing/routing.go | 25 +- syncapi/syncapi.go | 2 + userapi/api/api.go | 32 ++ userapi/internal/cross_signing.go | 16 +- userapi/internal/key_api.go | 13 + userapi/internal/user_api.go | 19 +- userapi/storage/interface.go | 3 +- .../postgres/cross_signing_keys_table.go | 60 ++- userapi/storage/shared/storage.go | 15 +- .../sqlite3/cross_signing_keys_table.go | 64 ++- userapi/storage/tables/interface.go | 7 +- userapi/types/storage.go | 7 +- 31 files changed, 1224 insertions(+), 287 deletions(-) create mode 100644 clientapi/auth/default_user_verifier.go create mode 100644 setup/mscs/msc3861/msc3861.go create mode 100644 setup/mscs/msc3861/msc3861_user_verifier.go diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index d3c5bcee0..179e91407 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -27,6 +27,7 @@ import ( "github.com/tidwall/gjson" capi "github.com/element-hq/dendrite/clientapi/api" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/test" "github.com/element-hq/dendrite/test/testrig" "github.com/element-hq/dendrite/userapi" @@ -48,7 +49,8 @@ func TestAdminCreateToken(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, bob: {}, @@ -199,7 +201,8 @@ func TestAdminListRegistrationTokens(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, bob: {}, @@ -317,7 +320,8 @@ func TestAdminGetRegistrationToken(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, bob: {}, @@ -418,7 +422,8 @@ func TestAdminDeleteRegistrationToken(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, bob: {}, @@ -512,7 +517,8 @@ func TestAdminUpdateRegistrationToken(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, bob: {}, @@ -697,7 +703,8 @@ func TestAdminResetPassword(t *testing.T) { // Needed for changing the password/login userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -801,8 +808,9 @@ func TestPurgeRoom(t *testing.T) { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -872,8 +880,10 @@ func TestAdminEvacuateRoom(t *testing.T) { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -976,8 +986,10 @@ func TestAdminEvacuateUser(t *testing.T) { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -1059,8 +1071,10 @@ func TestAdminMarkAsStale(t *testing.T) { rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -1147,8 +1161,10 @@ func TestAdminQueryEventReports(t *testing.T) { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -1376,8 +1392,10 @@ func TestEventReportsGetDelete(t *testing.T) { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index c32ed0fae..4e3612ce1 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -16,8 +16,6 @@ import ( "strings" "github.com/element-hq/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib/spec" - "github.com/matrix-org/util" ) // OWASP recommends at least 128 bits of entropy for tokens: https://www.owasp.org/index.php/Insufficient_Session-ID_Length @@ -37,51 +35,6 @@ type AccountDatabase interface { GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error) } -// VerifyUserFromRequest authenticates the HTTP request, -// on success returns Device of the requester. -// Finds local user or an application service user. -// Note: For an AS user, AS dummy device is returned. -// On failure returns an JSON error response which can be sent to the client. -func VerifyUserFromRequest( - req *http.Request, userAPI api.QueryAcccessTokenAPI, -) (*api.Device, *util.JSONResponse) { - // Try to find the Application Service user - token, err := ExtractAccessToken(req) - if err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: spec.MissingToken(err.Error()), - } - } - var res api.QueryAccessTokenResponse - err = userAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{ - AccessToken: token, - AppServiceUserID: req.URL.Query().Get("user_id"), - }, &res) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") - return nil, &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - if res.Err != "" { - if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison - return nil, &util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden(res.Err), - } - } - } - if res.Device == nil { - return nil, &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: spec.UnknownToken("Unknown token"), - } - } - return res.Device, nil -} - // GenerateAccessToken creates a new access token. Returns an error if failed to generate // random bytes. func GenerateAccessToken() (string, error) { diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index f01e48f80..c6e67f315 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -11,4 +11,5 @@ const ( LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" LoginTypeToken = "m.login.token" + LoginTypeCrossSigningReset = "org.matrix.cross_signing_reset" ) diff --git a/clientapi/auth/default_user_verifier.go b/clientapi/auth/default_user_verifier.go new file mode 100644 index 000000000..f0a48f518 --- /dev/null +++ b/clientapi/auth/default_user_verifier.go @@ -0,0 +1,59 @@ +package auth + +import ( + "net/http" + "strings" + + "github.com/element-hq/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" +) + +// DefaultUserVerifier implements UserVerifier interface +type DefaultUserVerifier struct { + UserAPI api.QueryAcccessTokenAPI +} + +// VerifyUserFromRequest authenticates the HTTP request, +// on success returns Device of the requester. +// Finds local user or an application service user. +// Note: For an AS user, AS dummy device is returned. +// On failure returns an JSON error response which can be sent to the client. +func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) { + util.GetLogger(req.Context()).Debug("Default VerifyUserFromRequest") + // Try to find the Application Service user + token, err := ExtractAccessToken(req) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.MissingToken(err.Error()), + } + } + var res api.QueryAccessTokenResponse + err = d.UserAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{ + AccessToken: token, + AppServiceUserID: req.URL.Query().Get("user_id"), + }, &res) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if res.Err != "" { + if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison + return nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden(res.Err), + } + } + } + if res.Device == nil { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.UnknownToken("Unknown token"), + } + } + return res.Device, nil +} diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index dbf862ca6..1c3bc4711 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -36,7 +36,9 @@ func AddPublicRoutes( fsAPI federationAPI.ClientFederationAPI, userAPI userapi.ClientUserAPI, userDirectoryProvider userapi.QuerySearchProfilesAPI, - extRoomsProvider api.ExtraPublicRoomsProvider, enableMetrics bool, + extRoomsProvider api.ExtraPublicRoomsProvider, + userVerifier httputil.UserVerifier, + enableMetrics bool, ) { js, natsClient := natsInstance.Prepare(processContext, &cfg.Global.JetStream) @@ -55,6 +57,7 @@ func AddPublicRoutes( cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, syncProducer, transactionsCache, fsAPI, - extRoomsProvider, natsClient, enableMetrics, + extRoomsProvider, natsClient, + userVerifier, enableMetrics, ) } diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 408811661..0b07724ab 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -21,6 +21,7 @@ import ( "golang.org/x/exp/constraints" clientapi "github.com/element-hq/dendrite/clientapi/api" + clienthttputil "github.com/element-hq/dendrite/clientapi/httputil" "github.com/element-hq/dendrite/internal/httputil" roomserverAPI "github.com/element-hq/dendrite/roomserver/api" "github.com/element-hq/dendrite/setup/config" @@ -517,6 +518,166 @@ func AdminCheckUsernameAvailable( } } +func AdminHandleUserDeviceByUserID( + req *http.Request, + userAPI userapi.ClientUserAPI, +) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID, ok := vars["userID"] + if !ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.MissingParam("Expecting user ID."), + } + } + + logger := util.GetLogger(req.Context()) + + switch req.Method { + case http.MethodPost: + local, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam(userID), + } + } + var payload struct { + DeviceID string `json:"device_id"` + } + if resErr := clienthttputil.UnmarshalJSONRequest(req, &payload); resErr != nil { + return *resErr + } + + var rs userapi.PerformDeviceCreationResponse + if err := userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ + Localpart: local, + ServerName: domain, + DeviceID: &payload.DeviceID, + IPAddr: "", + UserAgent: req.UserAgent(), + NoDeviceListUpdate: false, + FromRegistration: false, + }, &rs); err != nil { + logger.WithError(err).Debug("PerformDeviceCreation failed") + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + + logger.WithError(err).Debug("PerformDeviceCreation succeeded") + return util.JSONResponse{ + Code: http.StatusCreated, + JSON: struct{}{}, + } + case http.MethodGet: + var res userapi.QueryDevicesResponse + if err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &res); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + + jsonDevices := make([]deviceJSON, 0, len(res.Devices)) + for i := range res.Devices { + d := &res.Devices[i] + jsonDevices = append(jsonDevices, deviceJSON{ + DeviceID: d.ID, + DisplayName: d.DisplayName, + LastSeenIP: d.LastSeenIP, + LastSeenTS: d.LastSeenTS, + }) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct { + Devices []deviceJSON `json:"devices"` + Total int `json:"total"` + }{ + Devices: jsonDevices, + Total: len(res.Devices), + }, + } + default: + return util.JSONResponse{ + Code: http.StatusMethodNotAllowed, + JSON: struct{}{}, + } + } + +} + +type adminExternalID struct { + AuthProvider string `json:"auth_provider"` + ExternalID string `json:"external_id"` +} + +type adminCreateOrModifyAccountRequest struct { + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` + // TODO: the following fields are not used here, but they are used in Synapse. Probably we should reproduce the logic of the + // endpoint fully compatible. + // Password string `json:"password"` + // LogoutDevices bool `json:"logout_devices"` + // Threepids json.RawMessage `json:"threepids"` + // ExternalIDs []adminExternalID `json:"external_ids"` + // Admin bool `json:"admin"` + // Deactivated bool `json:"deactivated"` + // Locked bool `json:"locked"` +} + +func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI) util.JSONResponse { + logger := util.GetLogger(req.Context()) + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID, ok := vars["userID"] + if !ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.MissingParam("Expecting user ID."), + } + } + local, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam(userID), + } + } + var r adminCreateOrModifyAccountRequest + if resErr := clienthttputil.UnmarshalJSONRequest(req, &r); resErr != nil { + logger.Debugf("UnmarshalJSONRequest failed: %+v", *resErr) + return *resErr + } + logger.Debugf("adminCreateOrModifyAccountRequest is: %+v", r) + statusCode := http.StatusOK + { + var res userapi.PerformAccountCreationResponse + err = userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeUser, + Localpart: local, + ServerName: domain, + OnConflict: api.ConflictUpdate, + AvatarURL: r.AvatarURL, + DisplayName: r.DisplayName, + }, &res) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Debugln("Failed creating account") + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if res.AccountCreated { + statusCode = http.StatusCreated + } + } + + return util.JSONResponse{ + Code: statusCode, + JSON: nil, + } +} + // GetEventReports returns reported events for a given user/room. func GetEventReports( req *http.Request, diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index e6f093b5e..7bcd7093c 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -8,6 +8,8 @@ package routing import ( "net/http" + "strings" + "time" "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/auth/authtypes" @@ -39,28 +41,83 @@ func UploadCrossSigningDeviceKeys( if sessionID == "" { sessionID = util.RandomString(sessionIDLength) } - if uploadReq.Auth.Type != authtypes.LoginTypePassword { - return util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: newUserInteractiveResponse( - sessionID, - []authtypes.Flow{ - { - Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, - }, - }, - nil, - ), + + isCrossSigningSetup := false + masterKeyUpdatableWithoutUIA := false + { + var keysResp api.QueryMasterKeysResponse + keyserverAPI.QueryMasterKeys(req.Context(), &api.QueryMasterKeysRequest{UserID: device.UserID}, &keysResp) + if err := keysResp.Error; err != nil { + return convertKeyError(err) + } + if k := keysResp.Key; k != nil { + isCrossSigningSetup = true + if k.UpdatableWithoutUIABeforeMs != nil { + masterKeyUpdatableWithoutUIA = time.Now().UnixMilli() < *k.UpdatableWithoutUIABeforeMs + } } } - typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountAPI.QueryAccountByPassword, - Config: cfg, - } - if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { - return *authErr + + if isCrossSigningSetup { + // With MSC3861, UIA is not possible. Instead, the auth service has to explicitly mark the master key as replaceable. + if cfg.MSCs.MSC3861Enabled() { + if !masterKeyUpdatableWithoutUIA { + url := "" + if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" { + url = strings.Join([]string{m.AccountManagementURL, "?action=", authtypes.LoginTypeCrossSigningReset}, "") + } else { + url = m.Issuer + } + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + "dummy", + []authtypes.Flow{ + { + Stages: []authtypes.LoginType{authtypes.LoginTypeCrossSigningReset}, + }, + }, + map[string]interface{}{ + authtypes.LoginTypeCrossSigningReset: map[string]string{ + "url": url, + }, + }, + strings.Join([]string{ + "To reset your end-to-end encryption cross-signing, identity, you first need to approve it at", + url, + "and then try again.", + }, " "), + ), + } + } + // XXX: is it necessary? + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeCrossSigningReset) + } else { + if uploadReq.Auth.Type != authtypes.LoginTypePassword { + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + sessionID, + []authtypes.Flow{ + { + Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, + }, + }, + nil, + "", + ), + } + } + typePassword := auth.LoginTypePassword{ + GetAccountByPassword: accountAPI.QueryAccountByPassword, + Config: cfg, + } + if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { + return *authErr + } + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + } } - sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) @@ -108,7 +165,17 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes) if err := uploadRes.Error; err != nil { - switch { + return convertKeyError(err) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func convertKeyError(err *api.KeyError) util.JSONResponse { + switch { case err.IsInvalidSignature: return util.JSONResponse{ Code: http.StatusBadRequest, @@ -130,10 +197,4 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie JSON: spec.Unknown(err.Error()), } } - } - - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } } diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 59d9594d6..6258155db 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -67,6 +67,7 @@ func Password( }, }, nil, + "", ), } } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 5544dccd3..7bcda2069 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -234,6 +234,7 @@ type userInteractiveResponse struct { Completed []authtypes.LoginType `json:"completed"` Params map[string]interface{} `json:"params"` Session string `json:"session"` + Msg string `json:"msg,omitempty"` } // newUserInteractiveResponse will return a struct to be sent back to the client @@ -242,9 +243,10 @@ func newUserInteractiveResponse( sessionID string, fs []authtypes.Flow, params map[string]interface{}, + msg string, ) userInteractiveResponse { return userInteractiveResponse{ - fs, sessions.getCompletedStages(sessionID), params, sessionID, + fs, sessions.getCompletedStages(sessionID), params, sessionID, msg, } } @@ -817,7 +819,7 @@ func checkAndCompleteFlow( return util.JSONResponse{ Code: http.StatusUnauthorized, JSON: newUserInteractiveResponse(sessionID, - cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params), + cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params, ""), } } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 73cfcfc3b..ed93d0796 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -67,7 +67,9 @@ func Setup( transactionsCache *transactions.Cache, federationSender federationAPI.ClientFederationAPI, extRoomsProvider api.ExtraPublicRoomsProvider, - natsClient *nats.Conn, enableMetrics bool, + natsClient *nats.Conn, + userVerifier httputil.UserVerifier, + enableMetrics bool, ) { cfg := &dendriteCfg.ClientAPI mscCfg := &dendriteCfg.MSCs @@ -171,19 +173,19 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } dendriteAdminRouter.Handle("/admin/registrationTokens/new", - httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_registration_tokens_new", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminCreateNewRegistrationToken(req, cfg, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/registrationTokens", - httputil.MakeAdminAPI("admin_list_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_list_registration_tokens", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminListRegistrationTokens(req, cfg, userAPI) }), ).Methods(http.MethodGet, http.MethodOptions) dendriteAdminRouter.Handle("/admin/registrationTokens/{token}", - httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_get_registration_token", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { switch req.Method { case http.MethodGet: return AdminGetRegistrationToken(req, cfg, userAPI) @@ -202,43 +204,43 @@ func Setup( ).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", - httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_evacuate_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminEvacuateRoom(req, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}", - httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_evacuate_user", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminEvacuateUser(req, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", - httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_purge_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminPurgeRoom(req, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", - httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_reset_password", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminResetPassword(req, cfg, device, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/downloadState/{serverName}/{roomID}", - httputil.MakeAdminAPI("admin_download_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_download_state", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminDownloadState(req, device, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) dendriteAdminRouter.Handle("/admin/fulltext/reindex", - httputil.MakeAdminAPI("admin_fultext_reindex", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_fultext_reindex", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminReindex(req, cfg, device, natsClient) }), ).Methods(http.MethodGet, http.MethodOptions) dendriteAdminRouter.Handle("/admin/refreshDevices/{userID}", - httputil.MakeAdminAPI("admin_refresh_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_refresh_devices", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminMarkAsStale(req, cfg, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -252,7 +254,7 @@ func Setup( } synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}", - httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_server_notice", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r @@ -273,7 +275,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) synapseAdminRouter.Handle("/admin/v1/send_server_notice", - httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_server_notice", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r @@ -301,12 +303,12 @@ func Setup( unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter() v3mux.Handle("/createRoom", - httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("createRoom", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/join/{roomIDOrAlias}", - httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -340,11 +342,21 @@ func Setup( httputil.MakeServiceAdminAPI("admin_username_available", m.AdminToken, func(r *http.Request) util.JSONResponse { return AdminCheckUsernameAvailable(r, userAPI, cfg) })).Methods(http.MethodGet) + + synapseAdminRouter.Handle("/admin/v2/users/{userID}", + httputil.MakeServiceAdminAPI("admin_provision_user", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminCreateOrModifyAccount(r, userAPI) + })).Methods(http.MethodPut) + + synapseAdminRouter.Handle("/admin/v2/users/{userID}/devices", + httputil.MakeServiceAdminAPI("admin_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminHandleUserDeviceByUserID(r, userAPI) + })).Methods(http.MethodPost, http.MethodGet) } if mscCfg.Enabled("msc2753") { v3mux.Handle("/peek/{roomIDOrAlias}", - httputil.MakeAuthAPI(spec.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Peek, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -359,12 +371,12 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) } v3mux.Handle("/joined_rooms", - httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("joined_rooms", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetJoinedRooms(req, device, rsAPI) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/join", - httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -386,7 +398,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/leave", - httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("membership", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -400,7 +412,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/unpeek", - httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("unpeek", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -411,7 +423,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/ban", - httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("membership", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -420,7 +432,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/invite", - httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("membership", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -432,7 +444,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/kick", - httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("membership", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -441,7 +453,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/unban", - httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("membership", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -450,7 +462,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/send/{eventType}", - httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_message", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -459,7 +471,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", - httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_message", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -470,7 +482,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -478,7 +490,7 @@ func Setup( return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -486,7 +498,7 @@ func Setup( return GetAliases(req, rsAPI, device, vars["roomID"]) })).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -497,7 +509,7 @@ func Setup( return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -507,7 +519,7 @@ func Setup( }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", - httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_message", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -519,7 +531,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", - httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_message", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -533,7 +545,7 @@ func Setup( // TODO: clear based on some criteria roomHierarchyPaginationCache := NewRoomHierarchyPaginationCache() v1mux.Handle("/rooms/{roomID}/hierarchy", - httputil.MakeAuthAPI("spaces", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("spaces", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -567,7 +579,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/directory/room/{roomAlias}", - httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("directory_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -577,7 +589,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/directory/room/{roomAlias}", - httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("directory_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -596,7 +608,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/directory/list/room/{roomID}", - httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("directory_list", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -605,7 +617,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/directory/list/appservice/{networkID}/{roomID}", - httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("directory_list", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -616,7 +628,7 @@ func Setup( // Undocumented endpoint v3mux.Handle("/directory/list/appservice/{networkID}/{roomID}", - httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("directory_list", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -632,19 +644,19 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/logout", - httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("logout", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Logout(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/logout/all", - httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("logout", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return LogoutAll(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/typing/{userID}", - httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_typing", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -656,7 +668,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/redact/{eventID}", - httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_redact", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -665,7 +677,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", - httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_redact", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -676,7 +688,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/sendToDevice/{eventType}/{txnID}", - httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_to_device", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -690,7 +702,7 @@ func Setup( // rather than r0. It's an exact duplicate of the above handler. // TODO: Remove this if/when sytest is fixed! unstableMux.Handle("/sendToDevice/{eventType}/{txnID}", - httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("send_to_device", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -701,7 +713,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/account/whoami", - httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("whoami", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -710,7 +722,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/password", - httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("password", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -719,7 +731,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/account/deactivate", - httputil.MakeAuthAPI("deactivate", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("deactivate", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -739,7 +751,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTTPAPI("auth_fallback", userAPI, enableMetrics, func(w http.ResponseWriter, req *http.Request) { + httputil.MakeHTTPAPI("auth_fallback", userVerifier, enableMetrics, func(w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) AuthFallback(w, req, vars["authType"], cfg) }), @@ -748,7 +760,7 @@ func Setup( // Push rules v3mux.Handle("/pushrules", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("missing trailing slash"), @@ -757,13 +769,13 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetAllPushRules(req.Context(), device, userAPI) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("scope, kind and rule ID must be specified"), @@ -772,7 +784,7 @@ func Setup( ).Methods(http.MethodPut) v3mux.Handle("/pushrules/{scope}/", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -782,7 +794,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("missing trailing slash after scope"), @@ -791,7 +803,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope:[^/]+/?}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("kind and rule ID must be specified"), @@ -800,7 +812,7 @@ func Setup( ).Methods(http.MethodPut) v3mux.Handle("/pushrules/{scope}/{kind}/", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -810,7 +822,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope}/{kind}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("missing trailing slash after kind"), @@ -819,7 +831,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidParam("rule ID must be specified"), @@ -828,7 +840,7 @@ func Setup( ).Methods(http.MethodPut) v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -838,7 +850,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -852,7 +864,7 @@ func Setup( ).Methods(http.MethodPut) v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -862,7 +874,7 @@ func Setup( ).Methods(http.MethodDelete) v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -872,7 +884,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}", - httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("push_rules", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -904,7 +916,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/profile/{userID}/avatar_url", - httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("profile_avatar_url", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -929,7 +941,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/profile/{userID}/displayname", - httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("profile_displayname", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -946,19 +958,19 @@ func Setup( threePIDClient := base.CreateClient(dendriteCfg, nil) // TODO: Move this somewhere else, e.g. pass in as parameter v3mux.Handle("/account/3pid", - httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetAssociated3PIDs(req, userAPI, device) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/3pid", - httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CheckAndSave3PIDAssociation(req, userAPI, device, cfg, threePIDClient) }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/account/3pid/delete", - httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("account_3pid", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Forget3PID(req, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -970,7 +982,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/voip/turnServer", - httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("turn_server", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -979,13 +991,13 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/protocols", - httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_protocols", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Protocols(req, asAPI, device, "") }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/protocol/{protocolID}", - httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_protocols", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -995,7 +1007,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/user/{protocolID}", - httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_user", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1005,13 +1017,13 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/user", - httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_user", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return User(req, asAPI, device, "", req.URL.Query()) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/location/{protocolID}", - httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_location", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1021,7 +1033,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/location", - httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("thirdparty_location", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Location(req, asAPI, device, "", req.URL.Query()) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) @@ -1037,7 +1049,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userID}/account_data/{type}", - httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("user_account_data", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1047,7 +1059,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", - httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("user_account_data", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1057,7 +1069,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/user/{userID}/account_data/{type}", - httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("user_account_data", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1067,7 +1079,7 @@ func Setup( ).Methods(http.MethodGet) v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", - httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("user_account_data", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1077,7 +1089,7 @@ func Setup( ).Methods(http.MethodGet) v3mux.Handle("/admin/whois/{userID}", - httputil.MakeAuthAPI("admin_whois", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("admin_whois", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1087,7 +1099,7 @@ func Setup( ).Methods(http.MethodGet) v3mux.Handle("/user/{userID}/openid/request_token", - httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("openid_request_token", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1100,7 +1112,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/user_directory/search", - httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("userdirectory_search", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1126,7 +1138,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/read_markers", - httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_read_markers", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1139,7 +1151,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/forget", - httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_forget", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1152,7 +1164,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/upgrade", - httputil.MakeAuthAPI("rooms_upgrade", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_upgrade", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1162,13 +1174,13 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/devices", - httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_devices", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetDevicesByLocalpart(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", - httputil.MakeAuthAPI("get_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_device", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1178,7 +1190,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", - httputil.MakeAuthAPI("device_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("device_data", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1188,7 +1200,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", - httputil.MakeAuthAPI("delete_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("delete_device", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1198,25 +1210,25 @@ func Setup( ).Methods(http.MethodDelete, http.MethodOptions) v3mux.Handle("/delete_devices", - httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("delete_devices", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return DeleteDevices(req, userInteractiveAuth, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/notifications", - httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_notifications", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetNotifications(req, device, userAPI) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushers", - httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_pushers", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetPushers(req, device, userAPI) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/pushers/set", - httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("set_pushers", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1226,7 +1238,7 @@ func Setup( // Stub implementations for sytest v3mux.Handle("/events", - httputil.MakeAuthAPI("events", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("events", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "chunk": []interface{}{}, "start": "", @@ -1236,7 +1248,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/initialSync", - httputil.MakeAuthAPI("initial_sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("initial_sync", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "end": "", }} @@ -1244,7 +1256,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/rooms/{roomId}/tags", - httputil.MakeAuthAPI("get_tags", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_tags", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1254,7 +1266,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", - httputil.MakeAuthAPI("put_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("put_tag", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1264,7 +1276,7 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", - httputil.MakeAuthAPI("delete_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("delete_tag", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1274,7 +1286,7 @@ func Setup( ).Methods(http.MethodDelete, http.MethodOptions) v3mux.Handle("/capabilities", - httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("capabilities", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1284,7 +1296,7 @@ func Setup( // Key Backup Versions (Metadata) - getBackupKeysVersion := httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + getBackupKeysVersion := httputil.MakeAuthAPI("get_backup_keys_version", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1292,11 +1304,11 @@ func Setup( return KeyBackupVersion(req, userAPI, device, vars["version"]) }) - getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return KeyBackupVersion(req, userAPI, device, "") }) - putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1304,7 +1316,7 @@ func Setup( return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"]) }) - deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1312,7 +1324,7 @@ func Setup( return DeleteKeyBackupVersion(req, userAPI, device, vars["version"]) }) - postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CreateKeyBackupVersion(req, userAPI, device) }) @@ -1331,7 +1343,7 @@ func Setup( // Inserting E2E Backup Keys // Bulk room and session - putBackupKeys := httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + putBackupKeys := httputil.MakeAuthAPI("put_backup_keys", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { version := req.URL.Query().Get("version") if version == "" { return util.JSONResponse{ @@ -1348,7 +1360,7 @@ func Setup( }) // Single room bulk session - putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1380,7 +1392,7 @@ func Setup( }) // Single room, single session - putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1422,11 +1434,11 @@ func Setup( // Querying E2E Backup Keys - getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "") }) - getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1434,7 +1446,7 @@ func Setup( return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], "") }) - getBackupKeysRoomSession := httputil.MakeAuthAPI("get_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + getBackupKeysRoomSession := httputil.MakeAuthAPI("get_backup_keys_room_session", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1454,11 +1466,11 @@ func Setup( // Cross-signing device keys - postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, userAPI, device, userAPI, cfg) }) - postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadCrossSigningDeviceSignatures(req, userAPI, device) }, httputil.WithAllowGuests()) @@ -1470,27 +1482,27 @@ func Setup( // Supplying a device ID is deprecated. v3mux.Handle("/keys/upload/{deviceID}", - httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("keys_upload", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/upload", - httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("keys_upload", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", - httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("keys_query", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/claim", - httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("keys_claim", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return ClaimKeys(req, userAPI) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", - httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -1503,7 +1515,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/presence/{userId}/status", - httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("set_presence", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1512,7 +1524,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/presence/{userId}/status", - httputil.MakeAuthAPI("get_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_presence", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1522,7 +1534,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/joined_members", - httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_members", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1532,7 +1544,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/report/{eventID}", - httputil.MakeAuthAPI("report_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("report_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1542,7 +1554,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) synapseAdminRouter.Handle("/admin/v1/event_reports", - httputil.MakeAdminAPI("admin_report_events", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_report_events", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { from := parseUint64OrDefault(req.URL.Query().Get("from"), 0) limit := parseUint64OrDefault(req.URL.Query().Get("limit"), 100) dir := req.URL.Query().Get("dir") @@ -1556,7 +1568,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) synapseAdminRouter.Handle("/admin/v1/event_reports/{reportID}", - httputil.MakeAdminAPI("admin_report_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_report_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -1566,7 +1578,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) synapseAdminRouter.Handle("/admin/v1/event_reports/{reportID}", - httputil.MakeAdminAPI("admin_report_event_delete", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_report_event_delete", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 5a332d6fa..f04c2bd4f 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -58,17 +58,23 @@ func WithAuth() AuthAPIOption { } } +type UserVerifier interface { + // VerifyUserFromRequest authenticates the HTTP request, + // on success returns Device of the requester. + VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) +} + // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. func MakeAuthAPI( - metricsName string, userAPI userapi.QueryAcccessTokenAPI, + metricsName string, userVerifier UserVerifier, f func(*http.Request, *userapi.Device) util.JSONResponse, checks ...AuthAPIOption, ) http.Handler { h := func(req *http.Request) util.JSONResponse { logger := util.GetLogger(req.Context()) - device, err := auth.VerifyUserFromRequest(req, userAPI) + device, err := userVerifier.VerifyUserFromRequest(req) if err != nil { - logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code) + logger.Debugf("VerifyUserFromRequest %s -> HTTP %d: JSON %+v", req.RemoteAddr, err.Code, err.JSON) return *err } // add the user ID to the logger @@ -122,11 +128,11 @@ func MakeAuthAPI( // MakeAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be // completed by a user that is a server administrator. func MakeAdminAPI( - metricsName string, userAPI userapi.QueryAcccessTokenAPI, + metricsName string, userVerifier UserVerifier, f func(*http.Request, *userapi.Device) util.JSONResponse, ) http.Handler { - return MakeAuthAPI(metricsName, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - if device.AccountType != userapi.AccountTypeAdmin { + return MakeAuthAPI(metricsName, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if device == nil || device.AccountType != userapi.AccountTypeAdmin { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("This API can only be used by admin users."), @@ -136,8 +142,8 @@ func MakeAdminAPI( }) } -// MakeServiceAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be -// completed by a trusted service e.g. Matrix Auth Service. +// MakeServiceAdminAPI is a wrapper around MakeExternalAPI which enforces that the request can only be +// completed by a trusted service e.g. Matrix Auth Service (MAS). func MakeServiceAdminAPI( metricsName, serviceToken string, f func(*http.Request) util.JSONResponse, @@ -232,7 +238,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // MakeHTTPAPI adds Span metrics to the HTML Handler function // This is used to serve HTML alongside JSON error messages -func MakeHTTPAPI(metricsName string, userAPI userapi.QueryAcccessTokenAPI, enableMetrics bool, f func(http.ResponseWriter, *http.Request), checks ...AuthAPIOption) http.Handler { +func MakeHTTPAPI(metricsName string, userVerifier UserVerifier, enableMetrics bool, f func(http.ResponseWriter, *http.Request), checks ...AuthAPIOption) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { if req.Method == http.MethodOptions { util.SetCORSHeaders(w) @@ -252,7 +258,7 @@ func MakeHTTPAPI(metricsName string, userAPI userapi.QueryAcccessTokenAPI, enabl if opts.WithAuth { logger := util.GetLogger(req.Context()) - _, jsonErr := auth.VerifyUserFromRequest(req, userAPI) + _, jsonErr := userVerifier.VerifyUserFromRequest(req) if jsonErr != nil { w.WriteHeader(jsonErr.Code) if err := json.NewEncoder(w).Encode(jsonErr.JSON); err != nil { diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index ac20c886f..9793d8401 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -12,7 +12,6 @@ import ( "github.com/element-hq/dendrite/mediaapi/routing" "github.com/element-hq/dendrite/mediaapi/storage" "github.com/element-hq/dendrite/setup/config" - userapi "github.com/element-hq/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/sirupsen/logrus" @@ -23,10 +22,10 @@ func AddPublicRoutes( routers httputil.Routers, cm *sqlutil.Connections, cfg *config.Dendrite, - userAPI userapi.MediaUserAPI, client *fclient.Client, fedClient fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, + userVerifier httputil.UserVerifier, ) { mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database) if err != nil { @@ -34,6 +33,6 @@ func AddPublicRoutes( } routing.Setup( - routers, cfg, mediaDB, userAPI, client, fedClient, keyRing, + routers, cfg, mediaDB, client, fedClient, keyRing, userVerifier, ) } diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 45da8eba6..3d198f0d0 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -42,10 +42,10 @@ func Setup( routers httputil.Routers, cfg *config.Dendrite, db storage.Database, - userAPI userapi.MediaUserAPI, client *fclient.Client, federationClient fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, + userVerifier httputil.UserVerifier, ) { rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting) @@ -58,7 +58,7 @@ func Setup( } uploadHandler := httputil.MakeAuthAPI( - "upload", userAPI, + "upload", userVerifier, func(req *http.Request, dev *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, dev); r != nil { return *r @@ -67,7 +67,7 @@ func Setup( }, ) - configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + configHandler := httputil.MakeAuthAPI("config", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -97,13 +97,13 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) // v1 client endpoints requiring auth - downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()) + downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userVerifier, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()) v1mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions) v1mux.Handle("/thumbnail/{serverName}/{mediaId}", - httputil.MakeHTTPAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()), + httputil.MakeHTTPAPI("thumbnail", userVerifier, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()), ).Methods(http.MethodGet, http.MethodOptions) // same, but for federation diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index d6a51b651..1523a9ce8 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -1,15 +1,18 @@ package config +import "slices" + type MSCs struct { Matrix *Global `yaml:"-"` // The MSCs to enable. Supported MSCs include: + // 'msc3861': Delegate auth to an OIDC provider. This line MUST always go first if the msc is used https://github.com/matrix-org/matrix-spec-proposals/pull/3861 // 'msc2444': Peeking over federation - https://github.com/matrix-org/matrix-doc/pull/2444 // 'msc2753': Peeking via /sync - https://github.com/matrix-org/matrix-doc/pull/2753 // 'msc2836': Threading - https://github.com/matrix-org/matrix-doc/pull/2836 - // 'msc3861': Delegate auth to an OIDC provider https://github.com/matrix-org/matrix-spec-proposals/pull/3861 MSCs []string `yaml:"mscs"` + // MSC3861 contains config related to the experimental feature MSC3861. It takes effect only if 'msc3861' is included in 'MSCs' array MSC3861 *MSC3861 `yaml:"msc3861,omitempty"` Database DatabaseOptions `yaml:"database,omitempty"` @@ -42,6 +45,10 @@ func (c *MSCs) Verify(configErrs *ConfigErrors) { } } +func (c *MSCs) MSC3861Enabled() bool { + return slices.Contains(c.MSCs, "msc3861") && c.MSC3861 != nil && c.MSC3861.Enabled +} + type MSC3861 struct { Enabled bool `yaml:"enabled"` Issuer string `yaml:"issuer"` diff --git a/setup/monolith.go b/setup/monolith.go index 36d6794d6..8d8fadc90 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -7,9 +7,12 @@ package setup import ( + "net/http" + appserviceAPI "github.com/element-hq/dendrite/appservice/api" "github.com/element-hq/dendrite/clientapi" "github.com/element-hq/dendrite/clientapi/api" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/federationapi" federationAPI "github.com/element-hq/dendrite/federationapi/api" "github.com/element-hq/dendrite/internal/caching" @@ -27,6 +30,7 @@ import ( userapi "github.com/element-hq/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/util" ) // Monolith represents an instantiation of all dependencies required to build @@ -46,6 +50,8 @@ type Monolith struct { // Optional ExtPublicRoomsProvider api.ExtraPublicRoomsProvider ExtUserDirectoryProvider userapi.QuerySearchProfilesAPI + + UserVerifierProvider *UserVerifierProvider } // AddAllPublicRoutes attaches all public paths to the given router @@ -58,6 +64,10 @@ func (m *Monolith) AddAllPublicRoutes( caches *caching.Caches, enableMetrics bool, ) { + if m.UserVerifierProvider == nil { + m.UserVerifierProvider = NewUserVerifierProvider(&auth.DefaultUserVerifier{UserAPI: m.UserAPI}) + } + userDirectoryProvider := m.ExtUserDirectoryProvider if userDirectoryProvider == nil { userDirectoryProvider = m.UserAPI @@ -65,15 +75,29 @@ func (m *Monolith) AddAllPublicRoutes( clientapi.AddPublicRoutes( processCtx, routers, cfg, natsInstance, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), m.FederationAPI, m.UserAPI, userDirectoryProvider, - m.ExtPublicRoomsProvider, enableMetrics, + m.ExtPublicRoomsProvider, m.UserVerifierProvider, enableMetrics, ) federationapi.AddPublicRoutes( processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics, ) - mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.FedClient, m.KeyRing) - syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics) + mediaapi.AddPublicRoutes(routers, cm, cfg, m.Client, m.FedClient, m.KeyRing, m.UserVerifierProvider) + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, m.UserVerifierProvider, enableMetrics) if m.RelayAPI != nil { relayapi.AddPublicRoutes(routers, cfg, m.KeyRing, m.RelayAPI) } } + +type UserVerifierProvider struct { + UserVerifier httputil.UserVerifier +} + +func (u *UserVerifierProvider) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) { + return u.UserVerifier.VerifyUserFromRequest(req) +} + +func NewUserVerifierProvider(userVerifier httputil.UserVerifier) *UserVerifierProvider { + return &UserVerifierProvider{ + UserVerifier: userVerifier, + } +} diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 4322e8a2b..847e836ab 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -98,7 +98,7 @@ func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsRespons // Enable this MSC func Enable( cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, - userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, + userVerifier httputil.UserVerifier, keyRing gomatrixserverlib.JSONVerifier, ) error { db, err := NewDatabase(cm, &cfg.MSCs.Database) if err != nil { @@ -124,7 +124,7 @@ func Enable( }) routers.Client.Handle("/unstable/event_relationships", - httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI, fsAPI)), + httputil.MakeAuthAPI("eventRelationships", userVerifier, eventRelationshipHandler(db, rsAPI, fsAPI)), ).Methods(http.MethodPost, http.MethodOptions) routers.Federation.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( diff --git a/setup/mscs/msc3861/msc3861.go b/setup/mscs/msc3861/msc3861.go new file mode 100644 index 000000000..9b38af31f --- /dev/null +++ b/setup/mscs/msc3861/msc3861.go @@ -0,0 +1,17 @@ +package msc3861 + +import ( + "github.com/element-hq/dendrite/setup" +) + +func Enable(m *setup.Monolith) error { + userVerifier, err := newMSC3861UserVerifier( + m.UserAPI, m.Config.Global.ServerName, + m.Config.MSCs.MSC3861, !m.Config.ClientAPI.GuestsDisabled, + ) + if err != nil { + return err + } + m.UserVerifierProvider.UserVerifier = userVerifier + return nil +} diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go new file mode 100644 index 000000000..fcf5bb396 --- /dev/null +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -0,0 +1,444 @@ +package msc3861 + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "slices" + "strings" + + "github.com/element-hq/dendrite/clientapi/auth" + "github.com/element-hq/dendrite/setup/config" + "github.com/element-hq/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" +) + +const externalAuthProvider string = "oauth-delegated" + +// Scopes as defined by MSC2967 +// https://github.com/matrix-org/matrix-spec-proposals/pull/2967 +const ( + scopeMatrixAPI string = "urn:matrix:org.matrix.msc2967.client:api:*" + scopeMatrixGuest string = "urn:matrix:org.matrix.msc2967.client:api:guest" + scopeMatrixDevicePrefix string = "urn:matrix:org.matrix.msc2967.client:device:" +) + +type errCode string + +const ( + codeIntrospectionNot2xx errCode = "introspectionIsNot2xx" + codeInvalidClientToken errCode = "invalidClientToken" + codeAuthError errCode = "authError" + codeMxidError errCode = "mxidError" + codeOpenidConfigEndpointNon2xx errCode = "openidConfigEndpointNon2xx" + codeOpenidConfigDecodingFailed errCode = "openidConfigDecodingFailed" +) + +// MSC3861UserVerifier implements UserVerifier interface +type MSC3861UserVerifier struct { + userAPI api.UserInternalAPI + serverName spec.ServerName + cfg *config.MSC3861 + httpClient *http.Client + openIdConfig *OpenIDConfiguration + allowGuest bool +} + +func newMSC3861UserVerifier( + userAPI api.UserInternalAPI, + serverName spec.ServerName, + cfg *config.MSC3861, + allowGuest bool, +) (*MSC3861UserVerifier, error) { + openIdConfig, err := fetchOpenIDConfiguration(&http.Client{}, cfg.Issuer) + if err != nil { + return nil, err + } + return &MSC3861UserVerifier{ + userAPI: userAPI, + serverName: serverName, + cfg: cfg, + openIdConfig: openIdConfig, + allowGuest: allowGuest, + httpClient: http.DefaultClient, + }, nil +} + +type mscError struct { + Code errCode + Msg string +} + +func (r *mscError) Error() string { + return fmt.Sprintf("%s: %s", r.Code, r.Msg) +} + +// VerifyUserFromRequest authenticates the HTTP request, on success returns Device of the requester. +func (m *MSC3861UserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) { + util.GetLogger(req.Context()).Debug("MSC3861.VerifyUserFromRequest") + // Try to find the Application Service user + token, err := auth.ExtractAccessToken(req) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.MissingToken(err.Error()), + } + } + // TODO: try to get appservice user first. See https://github.com/element-hq/synapse/blob/develop/synapse/api/auth/msc3861_delegated.py#L273 + userData, err := m.getUserByAccessToken(req.Context(), token) + if err != nil { + switch e := err.(type) { + case (*mscError): + switch e.Code { + case codeIntrospectionNot2xx, codeOpenidConfigDecodingFailed, codeOpenidConfigEndpointNon2xx: + return nil, &util.JSONResponse{ + Code: http.StatusServiceUnavailable, + JSON: spec.Unknown(e.Error()), + } + case codeInvalidClientToken: + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Forbidden(e.Error()), + } + case codeAuthError: + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(e.Error()), + } + case codeMxidError: + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(e.Error()), + } + default: + r := util.ErrorResponse(err) + return nil, &r + } + default: + r := util.ErrorResponse(err) + return nil, &r + } + } + + // Do not record requests from MAS using the virtual `__oidc_admin` user. + if token != m.cfg.AdminToken { + // TODO: not sure which exact data we should record here. See the link for reference + // https://github.com/element-hq/synapse/blob/develop/synapse/api/auth/base.py#L365 + } + + if !m.allowGuest && userData.IsGuest { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.Forbidden(strings.Join([]string{"Insufficient scope: ", scopeMatrixAPI}, "")), + } + } + + return userData.Device, nil +} + +type requester struct { + Device *api.Device + UserID *spec.UserID + Scope []string + IsGuest bool +} + +func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token string) (*requester, error) { + var userID *spec.UserID + logger := util.GetLogger(ctx) + + if adminToken := m.cfg.AdminToken; adminToken != "" && token == adminToken { + // XXX: This is a temporary solution so that the admin API can be called by + // the OIDC provider. This will be removed once we have OIDC client + // credentials grant support in matrix-authentication-service. + logger.Info("Admin token used") + // XXX: that user doesn't exist and won't be provisioned. + adminUser, err := createUserID("__oidc_admin", m.serverName) + if err != nil { + return nil, err + } + return &requester{ + UserID: adminUser, + Scope: []string{"urn:synapse:admin:*"}, + Device: &api.Device{UserID: adminUser.Local(), AccountType: api.AccountTypeAdmin}, + }, nil + } + + introspectionResult, err := m.introspectToken(ctx, token) + if err != nil { + logger.WithError(err).Error("MSC3861UserVerifier:introspectToken") + return nil, err + } + logger.Debugf("Introspection result: %+v", *introspectionResult) + + if !introspectionResult.Active { + return nil, &mscError{Code: codeInvalidClientToken, Msg: "Token is not active"} + } + + scopes := introspectionResult.Scopes() + hasUserScope, hasGuestScope := slices.Contains(scopes, scopeMatrixAPI), slices.Contains(scopes, scopeMatrixGuest) + if !hasUserScope && !hasGuestScope { + return nil, &mscError{Code: codeInvalidClientToken, Msg: "No scope in token granting user rights"} + } + + sub := introspectionResult.Sub + if sub == "" { + return nil, &mscError{Code: codeInvalidClientToken, Msg: "Invalid sub claim in the introspection result"} + } + + localpart := "" + { + var rs api.QueryLocalpartExternalIDResponse + if err = m.userAPI.QueryExternalUserIDByLocalpartAndProvider(ctx, &api.QueryLocalpartExternalIDRequest{ + ExternalID: sub, + AuthProvider: externalAuthProvider, + }, &rs); err != nil && err != sql.ErrNoRows { + return nil, err + } + if l := rs.LocalpartExternalID; l != nil { + localpart = l.Localpart + } + } + + if localpart == "" { + // If we could not find a user via the external_id, it either does not exist, + // or the external_id was never recorded + username := introspectionResult.Username + if username == "" { + return nil, &mscError{Code: codeAuthError, Msg: "Invalid username claim in the introspection result"} + } + userID, err = createUserID(username, m.serverName) + if err != nil { + logger.WithError(err).Error("getUserByAccessToken:createUserID") + return nil, err + } + + // First try to find a user from the username claim + var account *api.Account + { + var rs api.QueryAccountByLocalpartResponse + err := m.userAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{Localpart: userID.Local(), ServerName: userID.Domain()}, &rs) + if err != nil && err != sql.ErrNoRows { + logger.WithError(err).Error("QueryAccountByLocalpart") + return nil, err + } + account = rs.Account + } + + if account == nil { + // If the user does not exist, we should create it on the fly + var rs api.PerformAccountCreationResponse + if err = m.userAPI.PerformAccountCreation(ctx, &api.PerformAccountCreationRequest{ + AccountType: api.AccountTypeUser, + Localpart: userID.Local(), + ServerName: userID.Domain(), + }, &rs); err != nil { + logger.WithError(err).Error("PerformAccountCreation") + return nil, err + } + } + + if err := m.userAPI.PerformLocalpartExternalUserIDCreation(ctx, &api.PerformLocalpartExternalUserIDCreationRequest{ + Localpart: userID.Local(), + ExternalID: sub, + AuthProvider: externalAuthProvider, + }); err != nil { + logger.WithError(err).Error("PerformLocalpartExternalUserIDCreation") + return nil, err + } + + localpart = userID.Local() + } + + if userID == nil { + userID, err = createUserID(localpart, m.serverName) + if err != nil { + logger.WithError(err).Error("getUserByAccessToken:createUserID") + return nil, err + } + } + + deviceIDs := make([]string, 0, 1) + for i := range scopes { + if s := scopes[i]; strings.HasPrefix(s, scopeMatrixDevicePrefix) { + deviceIDs = append(deviceIDs, s[len(scopeMatrixDevicePrefix):]) + } + } + + if len(deviceIDs) != 1 { + logger.Errorf("Invalid device IDs in scope: %+v", deviceIDs) + return nil, &mscError{Code: codeAuthError, Msg: "Invalid device IDs in scope"} + } + + var device *api.Device + + deviceID := deviceIDs[0] + if len(deviceID) > 255 || len(deviceID) < 1 { + return nil, &mscError{ + Code: codeAuthError, + Msg: strings.Join([]string{"Invalid device ID in scope: ", deviceID}, ""), + } + } + logger.Debugf("deviceID is: %s", deviceID) + logger.Debugf("scope is: %+v", scopes) + + userDeviceExists := false + { + var rs api.QueryDevicesResponse + err := m.userAPI.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: userID.String()}, &rs) + if err != nil && err != sql.ErrNoRows { + return nil, err + } + + for i := range rs.Devices { + if d := &rs.Devices[i]; d.ID == deviceID { + userDeviceExists = true + device = d + break + } + } + } + logger.Debugf("userDeviceExists is: %t", userDeviceExists) + if !userDeviceExists { + var rs api.PerformDeviceCreationResponse + deviceDisplayName := "OIDC-native client" + if err := m.userAPI.PerformDeviceCreation(ctx, &api.PerformDeviceCreationRequest{ + Localpart: localpart, + ServerName: m.serverName, + AccessToken: token, + DeviceID: &deviceID, + DeviceDisplayName: &deviceDisplayName, + // TODO: Cannot add IPAddr and Useragent values here. Should we care about it here? + }, &rs); err != nil { + logger.WithError(err).Error("PerformDeviceCreation") + return nil, err + } + device = rs.Device + logger.Debugf("PerformDeviceCreationResponse is: %+v", rs) + } + + + + return &requester{ + Device: device, + UserID: userID, + Scope: scopes, + IsGuest: hasGuestScope && !hasUserScope, + }, nil +} + +func createUserID(local string, serverName spec.ServerName) (*spec.UserID, error) { + userID, err := spec.NewUserID(strings.Join([]string{"@", local, ":", string(serverName)}, ""), false) + if err != nil { + return nil, &mscError{Code: codeMxidError, Msg: err.Error()} + } + return userID, nil +} + +func (m *MSC3861UserVerifier) introspectToken(ctx context.Context, token string) (*introspectionResponse, error) { + formBody := url.Values{"token": []string{token}} + encoded := formBody.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.openIdConfig.IntrospectionEndpoint, strings.NewReader(encoded)) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(m.cfg.ClientID, m.cfg.ClientSecret) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, err + } + body := resp.Body + defer resp.Body.Close() + + if c := resp.StatusCode; c < 200 || c >= 300 { + return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status ,"' response"}, "")) + } + var ir introspectionResponse + if err := json.NewDecoder(body).Decode(&ir); err != nil { + return nil, err + } + return &ir, nil +} + +type OpenIDConfiguration struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + JWKsURI string `json:"jwks_uri"` + RegistrationEndpoint string `json:"registration_endpoint"` + ScopesSupported []string `json:"scopes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + ResponseModesSupported []string `json:"response_modes_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + TokenEndpointAuthSigningAlgCaluesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported"` + RevocationEnpoint string `json:"revocation_endpoint"` + RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported"` + RevocationEndpointAuthSigningAlgValues []string `json:"revocation_endpoint_auth_signing_alg_values_supported"` + IntrospectionEndpoint string `json:"introspection_endpoint"` + IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported"` + IntrospectionEndpointAuthSigningAlgValues []string `json:"introspection_endpoint_auth_signing_alg_values_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` + UserinfoSigningAlgValuesSupported []string `json:"userinfo_signing_alg_values_supported"` + DisplayValuesSupported []string `json:"display_values_supported"` + ClaimTypesSupported []string `json:"claim_types_supported"` + ClaimsSupported []string `json:"claims_supported"` + ClaimsParameterSupported bool `json:"claims_parameter_supported"` + RequestParameterSupported bool `json:"request_parameter_supported"` + RequestURIParameterSupported bool `json:"request_uri_parameter_supported"` + PromptValuesSupported []string `json:"prompt_values_supported"` + DeviceAuthorizaEndpoint string `json:"device_authorization_endpoint"` + AccountManagementURI string `json:"account_management_uri"` + AccountManagementActionsSupported []string `json:"account_management_actions_supported"` +} + +func fetchOpenIDConfiguration(httpClient *http.Client, authHostURL string) (* + OpenIDConfiguration, error) { + u, err := url.Parse(authHostURL) + if err != nil { + return nil, err + } + u = u.JoinPath(".well-known/openid-configuration") + resp, err := httpClient.Get(u.String()) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, &mscError{Code: codeOpenidConfigEndpointNon2xx, Msg: ".well-known/openid-configuration endpoint returned non-200 response"} + } + var oic OpenIDConfiguration + if err := json.NewDecoder(resp.Body).Decode(&oic); err != nil { + return nil, &mscError{Code: codeOpenidConfigDecodingFailed, Msg: err.Error()} + } + return &oic, nil +} + +// introspectionResponse as described in the RFC https://datatracker.ietf.org/doc/html/rfc7662#section-2.2 +type introspectionResponse struct { + Active bool `json:"active"` // required + Scope string `json:"scope"` // optional + Username string `json:"username"` // optional + TokenType string `json:"token_type"` // optional + Exp *int64 `json:"exp"` // optional + Iat *int64 `json:"iat"` // optional + Nfb *int64 `json:"nfb"` // optional + Sub string `json:"sub"` // optional + Jti string `json:"jti"` // optional + Aud string `json:"aud"` // optional + Iss string `json:"iss"` // optional +} + +func (i *introspectionResponse) Scopes() []string { + return strings.Split(i.Scope, " ") +} diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go index 8df539ba7..3881b8e0c 100644 --- a/setup/mscs/mscs.go +++ b/setup/mscs/mscs.go @@ -16,6 +16,7 @@ import ( "github.com/element-hq/dendrite/setup" "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/setup/mscs/msc2836" + "github.com/element-hq/dendrite/setup/mscs/msc3861" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -34,10 +35,11 @@ func Enable(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Rout func EnableMSC(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, monolith *setup.Monolith, msc string, caches *caching.Caches) error { switch msc { case "msc2836": - return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing) + return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserVerifierProvider, monolith.KeyRing) case "msc2444": // enabled inside federationapi case "msc2753": // enabled inside clientapi - case "msc3861": // enabled inside clientapi + case "msc3861": + return msc3861.Enable(monolith) default: logrus.Warnf("EnableMSC: unknown MSC '%s', this MSC is either not supported or is natively supported by Dendrite", msc) } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index dcc78c859..484736988 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -36,16 +36,17 @@ func Setup( lazyLoadCache caching.LazyLoadCache, fts fulltext.Indexer, rateLimits *httputil.RateLimits, + userVerifier httputil.UserVerifier, ) { v1unstablemux := csMux.PathPrefix("/{apiversion:(?:v1|unstable)}/").Subrouter() v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() // TODO: Add AS support for all handlers below. - v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingSyncRequest(req, device) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r @@ -58,7 +59,7 @@ func Setup( }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/event/{eventID}", - httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_get_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -68,7 +69,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/filter", - httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("put_filter", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -78,7 +79,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/user/{userId}/filter/{filterId}", - httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("get_filter", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -87,12 +88,12 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingKeyChangeRequest(req, device) }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/context/{eventId}", - httputil.MakeAuthAPI("context", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("context", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -108,7 +109,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}", - httputil.MakeAuthAPI("relations", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relations", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -122,7 +123,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}", - httputil.MakeAuthAPI("relation_type", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relation_type", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -136,7 +137,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}", - httputil.MakeAuthAPI("relation_type_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relation_type_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -150,7 +151,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/search", - httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("search", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if !cfg.Fulltext.Enabled { return util.JSONResponse{ Code: http.StatusNotImplemented, @@ -173,7 +174,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/members", - httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("rooms_members", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 2b1dc9958..a45173dbe 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -42,6 +42,7 @@ func AddPublicRoutes( userAPI userapi.SyncUserAPI, rsAPI api.SyncRoomserverAPI, caches caching.LazyLoadCache, + userVerifier httputil.UserVerifier, enableMetrics bool, ) { js, natsClient := natsInstance.Prepare(processContext, &dendriteCfg.Global.JetStream) @@ -149,5 +150,6 @@ func AddPublicRoutes( routers.Client, requestPool, syncDB, userAPI, rsAPI, &dendriteCfg.SyncAPI, caches, fts, rateLimits, + userVerifier, ) } diff --git a/userapi/api/api.go b/userapi/api/api.go index f0ef26bf1..6899e5e21 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -32,6 +32,8 @@ type UserInternalAPI interface { QuerySearchProfilesAPI // used by p2p demos QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) + QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, req *QueryLocalpartExternalIDRequest, res *QueryLocalpartExternalIDResponse) (err error) + PerformLocalpartExternalUserIDCreation(ctx context.Context, req *PerformLocalpartExternalUserIDCreationRequest) (err error) } // api functions required by the appservice api @@ -129,6 +131,7 @@ type QuerySearchProfilesAPI interface { QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error } +// FIXME: typo in Acccess // common function for creating authenticated endpoints (used in client/media/sync api) type QueryAcccessTokenAPI interface { QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error @@ -316,6 +319,9 @@ type PerformAccountCreationRequest struct { Localpart string // Required: The localpart for this account. Ignored if account type is guest. ServerName spec.ServerName // optional: if not specified, default server name used instead + DisplayName string // optional: this is populated only by MAS. In the legacy flow it's not used + AvatarURL string // optional: this is populated only by MAS. In the legacy flow it's not used + AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. Password string // optional: if missing then this account will be a passwordless account OnConflict Conflict @@ -653,10 +659,26 @@ type QueryAccountByLocalpartResponse struct { Account *Account } +type QueryLocalpartExternalIDRequest struct { + ExternalID string + AuthProvider string +} + +type QueryLocalpartExternalIDResponse struct { + LocalpartExternalID *LocalpartExternalID +} + +type PerformLocalpartExternalUserIDCreationRequest struct { + Localpart string + ExternalID string + AuthProvider string +} + // API functions required by the clientapi type ClientKeyAPI interface { UploadDeviceKeysAPI QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) + QueryMasterKeys(ctx context.Context, req *QueryMasterKeysRequest, res *QueryMasterKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) @@ -918,6 +940,16 @@ type QueryKeysResponse struct { Error *KeyError } +type QueryMasterKeysRequest struct { + UserID string +} + +type QueryMasterKeysResponse struct { + Key *types.CrossSigningKey + // Set if there was a fatal error processing this query + Error *KeyError +} + type QueryKeyChangesRequest struct { // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning Offset int64 diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go index fe5d9f7d9..dfd426c36 100644 --- a/userapi/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -114,7 +114,9 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. byPurpose[fclient.CrossSigningKeyPurposeMaster] = req.MasterKey for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey - toStore[fclient.CrossSigningKeyPurposeMaster] = key + toStore[fclient.CrossSigningKeyPurposeMaster] = types.CrossSigningKey{ + KeyData: key, + } } hasMasterKey = true } @@ -130,7 +132,9 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. byPurpose[fclient.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey - toStore[fclient.CrossSigningKeyPurposeSelfSigning] = key + toStore[fclient.CrossSigningKeyPurposeSelfSigning] = types.CrossSigningKey{ + KeyData: key, + } } } @@ -145,7 +149,9 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. byPurpose[fclient.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey - toStore[fclient.CrossSigningKeyPurposeUserSigning] = key + toStore[fclient.CrossSigningKeyPurposeUserSigning] = types.CrossSigningKey{ + KeyData: key, + } } } @@ -198,7 +204,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. changed = true break } - if !bytes.Equal(old, new) { + if !bytes.Equal(old.KeyData, new.KeyData) { // One of the existing keys for a purpose we already knew about has // changed. changed = true @@ -210,7 +216,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. } // Store the keys. - if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { + if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore, nil); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), } diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 6cb11bcd2..98c387842 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -234,6 +234,19 @@ func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *a return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) } +func (a *UserInternalAPI) QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) { + crossSigningKeyMap, err := a.KeyDatabase.CrossSigningKeysDataForUserAndKeyType(ctx, req.UserID, fclient.CrossSigningKeyPurposeMaster) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query user cross signing master keys: %s", err), + } + return + } + if key, ok := crossSigningKeyMap[fclient.CrossSigningKeyPurposeMaster]; ok { + res.Key = &key + } +} + // nolint:gocyclo func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { var respMu sync.Mutex diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 666e75f93..a7760c1b2 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -7,6 +7,7 @@ package internal import ( + "cmp" "context" "database/sql" "encoding/json" @@ -247,10 +248,17 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil { + displayName := cmp.Or(req.DisplayName, req.Localpart) + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, displayName); err != nil { return fmt.Errorf("a.DB.SetDisplayName: %w", err) } + if req.AvatarURL != "" { + if _, _, err := a.DB.SetAvatarURL(ctx, req.Localpart, serverName, req.AvatarURL); err != nil { + return fmt.Errorf("a.DB.SetAvatarURL: %w", err) + } + } + postRegisterJoinRooms(a.Config, acc, a.RSAPI) res.AccountCreated = true @@ -594,6 +602,15 @@ func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api. return } +func (a *UserInternalAPI) PerformLocalpartExternalUserIDCreation(ctx context.Context, req *api.PerformLocalpartExternalUserIDCreationRequest) (err error) { + return a.DB.CreateLocalpartExternalID(ctx, req.Localpart, req.ExternalID, req.AuthProvider) +} + +func (a *UserInternalAPI) QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, req *api.QueryLocalpartExternalIDRequest, res *api.QueryLocalpartExternalIDResponse) (err error) { + res.LocalpartExternalID, err = a.DB.GetLocalpartForExternalID(ctx, req.ExternalID, req.AuthProvider) + return +} + // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // creating a 'device'. func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 13d8c2013..6dbf97c5e 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -226,9 +226,10 @@ type KeyDatabase interface { CrossSigningKeysForUser(ctx context.Context, userID string) (map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey, error) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) + CrossSigningKeysDataForUserAndKeyType(ctx context.Context, userID string, keyType fclient.CrossSigningKeyPurpose) (types.CrossSigningKeyMap, error) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) - StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error + StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap, updatableWithoutUIABeforeMs *int64) error StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature spec.Base64Bytes) error DeleteStaleDeviceLists( diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index a8566e69b..a8f0d8cb8 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -24,23 +24,29 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( user_id TEXT NOT NULL, key_type SMALLINT NOT NULL, key_data TEXT NOT NULL, + updatable_without_uia_before_ms BIGINT DEFAULT NULL, PRIMARY KEY (user_id, key_type) ); ` const selectCrossSigningKeysForUserSQL = "" + - "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + "SELECT key_type, key_data, updatable_without_uia_before_ms FROM keyserver_cross_signing_keys" + " WHERE user_id = $1" +const selectCrossSigningKeysForUserAndKeyTypeSQL = "" + + "SELECT key_type, key_data, updatable_without_uia_before_ms FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1 AND key_type = $2" + const upsertCrossSigningKeysForUserSQL = "" + "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + " VALUES($1, $2, $3)" + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt } func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -53,6 +59,7 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro } return s, sqlutil.StatementList{ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, }.Prepare(db) } @@ -69,27 +76,64 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( for rows.Next() { var keyTypeInt int16 var keyData spec.Base64Bytes - if err = rows.Scan(&keyTypeInt, &keyData); err != nil { + var updatableWithoutUIABeforeMs *int64 + if err = rows.Scan(&keyTypeInt, &keyData, &updatableWithoutUIABeforeMs); err != nil { return nil, err } keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] if !ok { return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) } - r[keyType] = keyData + r[keyType] = types.CrossSigningKey{ + UpdatableWithoutUIABeforeMs: updatableWithoutUIABeforeMs, + KeyData: keyData, + } + } + err = rows.Err() + return +} + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUserAndKeyType( + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, +) (r types.CrossSigningKeyMap, err error) { + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return nil, fmt.Errorf("unknown key purpose %q", keyType) + } + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserAndKeyTypeStmt).QueryContext(ctx, userID, keyTypeInt) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectCrossSigningKeysForUserAndKeyType: rows.close() failed") + r = types.CrossSigningKeyMap{} + for rows.Next() { + var keyTypeInt int16 + var keyData spec.Base64Bytes + var updatableWithoutUIABeforeMs *int64 + if err = rows.Scan(&keyTypeInt, &keyData, &updatableWithoutUIABeforeMs); err != nil { + return nil, err + } + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = types.CrossSigningKey{ + UpdatableWithoutUIABeforeMs: updatableWithoutUIABeforeMs, + KeyData: keyData, + } } err = rows.Err() return } func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, updatableWithoutUIABeforeMs *int64, ) error { keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] if !ok { return fmt.Errorf("unknown key purpose %q", keyType) } - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData, updatableWithoutUIABeforeMs); err != nil { return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) } return nil diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index aade4be1f..80b225c97 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -1101,12 +1101,12 @@ func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string } results := map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey{} for purpose, key := range keyMap { - keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) + keyID := gomatrixserverlib.KeyID("ed25519:" + key.KeyData.Encode()) result := fclient.CrossSigningKey{ UserID: userID, Usage: []fclient.CrossSigningKeyPurpose{purpose}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{ - keyID: key, + keyID: key.KeyData, }, } sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) @@ -1137,16 +1137,21 @@ func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID st return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) } +// CrossSigningKeysForUserAndKeyType returns the latest known cross-signing keys for a user and key type, if any. +func (d *KeyDatabase) CrossSigningKeysDataForUserAndKeyType(ctx context.Context, userID string, keyType fclient.CrossSigningKeyPurpose) (types.CrossSigningKeyMap, error) { + return d.CrossSigningKeysTable.SelectCrossSigningKeysForUserAndKeyType(ctx, nil, userID, keyType) +} + // CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) } // StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. -func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { +func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap, updatableWithoutUIABeforeMs *int64) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for keyType, keyData := range keyMap { - if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { + for keyType, key := range keyMap { + if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, key.KeyData, key.UpdatableWithoutUIABeforeMs); err != nil { return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) } } diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index dd8923d30..9c4d3cb57 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -24,22 +24,28 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( user_id TEXT NOT NULL, key_type INTEGER NOT NULL, key_data TEXT NOT NULL, + updatable_without_uia_before_ms BIGINT DEFAULT NULL, PRIMARY KEY (user_id, key_type) ); ` const selectCrossSigningKeysForUserSQL = "" + - "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + "SELECT key_type, key_data, updatable_without_uia_before_ms FROM keyserver_cross_signing_keys" + " WHERE user_id = $1" +const selectCrossSigningKeysForUserAndKeyTypeSQL = "" + + "SELECT key_type, key_data, updatable_without_uia_before_ms FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1 AND key_type = $2" + const upsertCrossSigningKeysForUserSQL = "" + - "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + - " VALUES($1, $2, $3)" + "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data, updatable_without_uia_before_ms)" + + " VALUES($1, $2, $3, $4)" type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt } func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -52,6 +58,7 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) } return s, sqlutil.StatementList{ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, }.Prepare(db) } @@ -68,27 +75,64 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( for rows.Next() { var keyTypeInt int16 var keyData spec.Base64Bytes - if err = rows.Scan(&keyTypeInt, &keyData); err != nil { + var updatableWithoutUiaBeforeMs *int64 + if err = rows.Scan(&keyTypeInt, &keyData, &updatableWithoutUiaBeforeMs); err != nil { return nil, err } keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] if !ok { return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) } - r[keyType] = keyData + r[keyType] = types.CrossSigningKey{ + UpdatableWithoutUIABeforeMs: updatableWithoutUiaBeforeMs, + KeyData: keyData, + } + } + err = rows.Err() + return +} + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUserAndKeyType( + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, +) (r types.CrossSigningKeyMap, err error) { + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return nil, fmt.Errorf("unknown key purpose %q", keyType) + } + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserAndKeyTypeStmt).QueryContext(ctx, userID, keyTypeInt) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectCrossSigningKeysForUserAndKeyType: rows.close() failed") + r = types.CrossSigningKeyMap{} + for rows.Next() { + var keyTypeInt int16 + var keyData spec.Base64Bytes + var updatableWithoutUIABeforeMs *int64 + if err = rows.Scan(&keyTypeInt, &keyData, &updatableWithoutUIABeforeMs); err != nil { + return nil, err + } + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = types.CrossSigningKey{ + UpdatableWithoutUIABeforeMs: updatableWithoutUIABeforeMs, + KeyData: keyData, + } } err = rows.Err() return } func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, updatableWithoutUIABeforeMs *int64, ) error { keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] if !ok { return fmt.Errorf("unknown key purpose %q", keyType) } - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData, updatableWithoutUIABeforeMs); err != nil { return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) } return nil diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 7b141629a..cfd1e571a 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -133,10 +133,6 @@ type LocalpartExternalIDsTable interface { Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error } -type UIAuthSessionsTable interface { - SelectByID(ctx context.Context, txn *sql.Tx, sessionID int) (*api.UIAuthSession, error) -} - type NotificationFilter uint32 const ( @@ -202,7 +198,8 @@ type StaleDeviceLists interface { type CrossSigningKeys interface { SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) - UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error + SelectCrossSigningKeysForUserAndKeyType(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose) (r types.CrossSigningKeyMap, err error) + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, updatableWithoutUIABeforeMs *int64) error } type CrossSigningSigs interface { diff --git a/userapi/types/storage.go b/userapi/types/storage.go index 971f3dc9a..260e2a251 100644 --- a/userapi/types/storage.go +++ b/userapi/types/storage.go @@ -37,8 +37,13 @@ var KeyTypeIntToPurpose = map[int16]fclient.CrossSigningKeyPurpose{ 3: fclient.CrossSigningKeyPurposeUserSigning, } +type CrossSigningKey struct { + UpdatableWithoutUIABeforeMs *int64 + KeyData spec.Base64Bytes +} + // Map of purpose -> public key -type CrossSigningKeyMap map[fclient.CrossSigningKeyPurpose]spec.Base64Bytes +type CrossSigningKeyMap map[fclient.CrossSigningKeyPurpose]CrossSigningKey // Map of user ID -> key ID -> signature type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes From 9d9841d02e1feabe339a47a79a5116023f0b1ce1 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Mon, 30 Dec 2024 02:11:30 +0000 Subject: [PATCH 06/71] mas: added "admin's replacement without uia" endpoint i.e. /_synapse/admin/v1/users/{userID}/_allow_cross_signing_replacement_without_uia --- clientapi/routing/admin.go | 53 +++++++++++++++++++ clientapi/routing/key_crosssigning.go | 40 +++++++------- clientapi/routing/routing.go | 5 ++ userapi/api/api.go | 14 +++++ userapi/internal/cross_signing.go | 10 ++++ userapi/storage/interface.go | 2 + .../postgres/cross_signing_keys_table.go | 29 ++++++++-- userapi/storage/shared/storage.go | 14 ++++- .../sqlite3/cross_signing_keys_table.go | 29 ++++++++-- userapi/storage/tables/interface.go | 1 + 10 files changed, 168 insertions(+), 29 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 0b07724ab..ce5476ef4 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -2,6 +2,7 @@ package routing import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -30,6 +31,8 @@ import ( userapi "github.com/element-hq/dendrite/userapi/api" ) +const replacementPeriod = 10 * time.Minute + var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { @@ -607,6 +610,56 @@ func AdminHandleUserDeviceByUserID( } +func AdminAllowCrossSigningReplacementWithoutUIA( + req *http.Request, + userAPI userapi.ClientUserAPI, +) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userIDstr, ok := vars["userID"] + userID, err := spec.NewUserID(userIDstr, false) + if !ok || err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.MissingParam("User not found."), + } + } + + switch req.Method { + case http.MethodPost: + rq := userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest{ + UserID: userID.String(), + Duration: replacementPeriod, + } + var rs userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse + err = userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA(req.Context(), &rq, &rs) + if err == sql.ErrNoRows { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.MissingParam("User has no master cross-signing key"), + } + } else if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(err.Error()), + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: map[string]int64{"updatable_without_uia_before_ms": rs.Timestamp}, + } + default: + return util.JSONResponse{ + Code: http.StatusMethodNotAllowed, + JSON: spec.Unknown("Method not allowed."), + } + } + +} + type adminExternalID struct { AuthProvider string `json:"auth_provider"` ExternalID string `json:"external_id"` diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 7bcd7093c..78b66400b 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -176,25 +176,25 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie func convertKeyError(err *api.KeyError) util.JSONResponse { switch { - case err.IsInvalidSignature: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.InvalidSignature(err.Error()), - } - case err.IsMissingParam: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.MissingParam(err.Error()), - } - case err.IsInvalidParam: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.InvalidParam(err.Error()), - } - default: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown(err.Error()), - } + case err.IsInvalidSignature: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidSignature(err.Error()), + } + case err.IsMissingParam: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.MissingParam(err.Error()), } + case err.IsInvalidParam: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam(err.Error()), + } + default: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown(err.Error()), + } + } } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index ed93d0796..945d0e48e 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -352,6 +352,11 @@ func Setup( httputil.MakeServiceAdminAPI("admin_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { return AdminHandleUserDeviceByUserID(r, userAPI) })).Methods(http.MethodPost, http.MethodGet) + + synapseAdminRouter.Handle("/admin/v1/users/{userID}/_allow_cross_signing_replacement_without_uia", + httputil.MakeServiceAdminAPI("admin_allow_cross_signing_replacement_without_uia", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminAllowCrossSigningReplacementWithoutUIA(r, userAPI) + })).Methods(http.MethodPost) } if mscCfg.Enabled("msc2753") { diff --git a/userapi/api/api.go b/userapi/api/api.go index 6899e5e21..bcd5c9c0b 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -680,6 +680,11 @@ type ClientKeyAPI interface { QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QueryMasterKeys(ctx context.Context, req *QueryMasterKeysRequest, res *QueryMasterKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error + PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA( + ctx context.Context, + req *PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest, + res *PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse, + ) error PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) // PerformClaimKeys claims one-time keys for use in pre-key messages @@ -908,6 +913,15 @@ type PerformUploadDeviceKeysResponse struct { Error *KeyError } +type PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest struct { + UserID string + Duration time.Duration +} + +type PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse struct { + Timestamp int64 +} + type PerformUploadDeviceSignaturesRequest struct { Signatures map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice // The user that uploaded the sig, should be populated by the clientapi. diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go index dfd426c36..a93fba150 100644 --- a/userapi/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -96,6 +96,16 @@ func sanityCheckKey(key fclient.CrossSigningKey, userID string, purpose fclient. return nil } +func (a *UserInternalAPI) PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA( + ctx context.Context, + req *api.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest, + res *api.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse, +) error { + var err error + res.Timestamp, err = a.KeyDatabase.UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx, req.UserID, req.Duration) + return err +} + // nolint:gocyclo func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { // Find the keys to store. diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 6dbf97c5e..11b360952 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -10,6 +10,7 @@ import ( "context" "encoding/json" "errors" + "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" @@ -231,6 +232,7 @@ type KeyDatabase interface { StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap, updatableWithoutUIABeforeMs *int64) error StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature spec.Base64Bytes) error + UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, userID string, duration time.Duration) (int64, error) DeleteStaleDeviceLists( ctx context.Context, diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index a8f0d8cb8..7e66a0114 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -10,6 +10,7 @@ import ( "context" "database/sql" "fmt" + "time" "github.com/element-hq/dendrite/internal" "github.com/element-hq/dendrite/internal/sqlutil" @@ -42,11 +43,17 @@ const upsertCrossSigningKeysForUserSQL = "" + " VALUES($1, $2, $3)" + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" +const updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL = "" + + "UPDATE keyserver_cross_signing_keys" + + " SET updatable_without_uia_before_ms = $3" + + " WHERE user_id = $1 AND key_type = $2" + type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt + updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt *sql.Stmt } func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -61,6 +68,7 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, + {&s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt, updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL}, }.Prepare(db) } @@ -138,3 +146,16 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( } return nil } + +func (s *crossSigningKeysStatements) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, txn *sql.Tx, userID string, duration time.Duration) (int64, error) { + keyTypeInt, _ := types.KeyTypePurposeToInt[fclient.CrossSigningKeyPurposeMaster] + ts := time.Now().Add(duration).UnixMilli() + result, err := sqlutil.TxStmt(txn, s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt).ExecContext(ctx, userID, keyTypeInt, ts) + if err != nil { + return -1, err + } + if n, _ := result.RowsAffected(); n == 0 { + return -1, sql.ErrNoRows + } + return ts, nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 80b225c97..2feca052d 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -1159,7 +1159,19 @@ func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID s }) } -// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. +// UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA updates the 'updatable_without_uia_before_ms' attribute of the master cross-signing key. +// Normally this attribute depending on its value marks the master key as replaceable without UIA. +func (d *KeyDatabase) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, userID string, duration time.Duration) (int64, error) { + var ts int64 + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + ts, err = d.CrossSigningKeysTable.UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx, txn, userID, duration) + return err + }) + return ts, err +} + +// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/device. func (d *KeyDatabase) StoreCrossSigningSigsForTarget( ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index 9c4d3cb57..47a39f6b1 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -10,6 +10,7 @@ import ( "context" "database/sql" "fmt" + "time" "github.com/element-hq/dendrite/internal" "github.com/element-hq/dendrite/internal/sqlutil" @@ -41,11 +42,17 @@ const upsertCrossSigningKeysForUserSQL = "" + "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data, updatable_without_uia_before_ms)" + " VALUES($1, $2, $3, $4)" +const updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL = "" + + "UPDATE keyserver_cross_signing_keys" + + " SET updatable_without_uia_before_ms = $3" + + " WHERE user_id = $1 AND key_type = $2" + type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt + updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt *sql.Stmt } func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -60,6 +67,7 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, + {&s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt, updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL}, }.Prepare(db) } @@ -137,3 +145,16 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( } return nil } + +func (s *crossSigningKeysStatements) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, txn *sql.Tx, userID string, duration time.Duration) (int64, error) { + keyTypeInt, _ := types.KeyTypePurposeToInt[fclient.CrossSigningKeyPurposeMaster] + ts := time.Now().Add(duration).UnixMilli() + result, err := sqlutil.TxStmt(txn, s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt).ExecContext(ctx, userID, keyTypeInt, ts) + if err != nil { + return -1, err + } + if n, _ := result.RowsAffected(); n == 0 { + return -1, sql.ErrNoRows + } + return ts, nil +} diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index cfd1e571a..8e629914e 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -200,6 +200,7 @@ type CrossSigningKeys interface { SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) SelectCrossSigningKeysForUserAndKeyType(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose) (r types.CrossSigningKeyMap, err error) UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, updatableWithoutUIABeforeMs *int64) error + UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, txn *sql.Tx, userID string, duration time.Duration) (int64, error) } type CrossSigningSigs interface { From 4f406e262ae5375009767c1ac8b7c0c31a45ba0b Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Mon, 30 Dec 2024 17:04:28 +0000 Subject: [PATCH 07/71] minor goimports fix --- federationapi/storage/storage_wasm.go | 2 +- relayapi/storage/storage_wasm.go | 2 +- userapi/storage/storage_wasm.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/federationapi/storage/storage_wasm.go b/federationapi/storage/storage_wasm.go index 9f630f37d..10ed7d2a1 100644 --- a/federationapi/storage/storage_wasm.go +++ b/federationapi/storage/storage_wasm.go @@ -14,7 +14,7 @@ import ( "github.com/element-hq/dendrite/internal/caching" "github.com/element-hq/dendrite/internal/sqlutil" "github.com/element-hq/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // NewDatabase opens a new database diff --git a/relayapi/storage/storage_wasm.go b/relayapi/storage/storage_wasm.go index 86ba972a9..69f4fa174 100644 --- a/relayapi/storage/storage_wasm.go +++ b/relayapi/storage/storage_wasm.go @@ -13,7 +13,7 @@ import ( "github.com/element-hq/dendrite/internal/sqlutil" "github.com/element-hq/dendrite/relayapi/storage/sqlite3" "github.com/element-hq/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // NewDatabase opens a new database diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 3a2afdf06..309aecd66 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -14,7 +14,7 @@ import ( "github.com/element-hq/dendrite/internal/sqlutil" "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/userapi/storage/sqlite3" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func NewUserDatabase( From b95070393e01d41a1b8e333e5273e17802b0d3cb Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Mon, 30 Dec 2024 17:05:27 +0000 Subject: [PATCH 08/71] mas: return correct http code If access token expires the client(i.e. element) expects a specific response with http code 401 and spec.UnknownToken --- setup/mscs/msc3861/msc3861_user_verifier.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index fcf5bb396..dab71d457 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -102,8 +102,8 @@ func (m *MSC3861UserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Dev } case codeInvalidClientToken: return nil, &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.Forbidden(e.Error()), + Code: http.StatusUnauthorized, + JSON: spec.UnknownToken(e.Error()), } case codeAuthError: return nil, &util.JSONResponse{ From 9ebcebee431117e2be0b7f13d2872e9bf449ffb8 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Mon, 30 Dec 2024 17:12:22 +0000 Subject: [PATCH 09/71] another goimports fix --- setup/mscs/msc3861/msc3861_user_verifier.go | 7 ++----- userapi/storage/shared/storage.go | 2 +- userapi/types/storage.go | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index dab71d457..597b844e2 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -322,8 +322,6 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st logger.Debugf("PerformDeviceCreationResponse is: %+v", rs) } - - return &requester{ Device: device, UserID: userID, @@ -358,7 +356,7 @@ func (m *MSC3861UserVerifier) introspectToken(ctx context.Context, token string) defer resp.Body.Close() if c := resp.StatusCode; c < 200 || c >= 300 { - return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status ,"' response"}, "")) + return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status, "' response"}, "")) } var ir introspectionResponse if err := json.NewDecoder(body).Decode(&ir); err != nil { @@ -402,8 +400,7 @@ type OpenIDConfiguration struct { AccountManagementActionsSupported []string `json:"account_management_actions_supported"` } -func fetchOpenIDConfiguration(httpClient *http.Client, authHostURL string) (* - OpenIDConfiguration, error) { +func fetchOpenIDConfiguration(httpClient *http.Client, authHostURL string) (*OpenIDConfiguration, error) { u, err := url.Parse(authHostURL) if err != nil { return nil, err diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 2feca052d..e76cde61b 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -1163,7 +1163,7 @@ func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID s // Normally this attribute depending on its value marks the master key as replaceable without UIA. func (d *KeyDatabase) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, userID string, duration time.Duration) (int64, error) { var ts int64 - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error ts, err = d.CrossSigningKeysTable.UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx, txn, userID, duration) return err diff --git a/userapi/types/storage.go b/userapi/types/storage.go index 260e2a251..3e58c8e67 100644 --- a/userapi/types/storage.go +++ b/userapi/types/storage.go @@ -39,7 +39,7 @@ var KeyTypeIntToPurpose = map[int16]fclient.CrossSigningKeyPurpose{ type CrossSigningKey struct { UpdatableWithoutUIABeforeMs *int64 - KeyData spec.Base64Bytes + KeyData spec.Base64Bytes } // Map of purpose -> public key From be8d490e56002b45a1e97caa62435a9826f8005b Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Mon, 30 Dec 2024 17:14:04 +0000 Subject: [PATCH 10/71] mas: implemented PUT /admin/v2/users/{userID} endpoint MAS requires this endpoint to fetch the data for the account management page --- clientapi/routing/admin.go | 72 ++++++++++++++++++++++ clientapi/routing/routing.go | 13 +++- userapi/api/api.go | 11 ++-- userapi/storage/postgres/accounts_table.go | 6 +- userapi/storage/sqlite3/accounts_table.go | 5 +- 5 files changed, 95 insertions(+), 12 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index ce5476ef4..0532a5775 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -21,8 +21,10 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/exp/constraints" + appserviceAPI "github.com/element-hq/dendrite/appservice/api" clientapi "github.com/element-hq/dendrite/clientapi/api" clienthttputil "github.com/element-hq/dendrite/clientapi/httputil" + "github.com/element-hq/dendrite/clientapi/userutil" "github.com/element-hq/dendrite/internal/httputil" roomserverAPI "github.com/element-hq/dendrite/roomserver/api" "github.com/element-hq/dendrite/setup/config" @@ -731,6 +733,76 @@ func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI } } +func AdminRetrieveAccount(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + logger := util.GetLogger(req.Context()) + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID, ok := vars["userID"] + if !ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.MissingParam("Expecting user ID."), + } + } + local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam(err.Error()), + } + } + + body := struct { + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` + Deactivated bool `json:"deactivated"` + }{} + + { + var rs api.QueryAccountByLocalpartResponse + err := userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{Localpart: local, ServerName: domain}, &rs) + if err == sql.ErrNoRows { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("User '%s' not found", userID)), + } + } else if err != nil { + logger.WithError(err).Error("userAPI.QueryAccountByLocalpart") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(err.Error()), + } + } + body.Deactivated = rs.Account.Deactivated + } + + { + profile, err := userAPI.QueryProfile(req.Context(), userID) + if err != nil { + if err == appserviceAPI.ErrProfileNotExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(err.Error()), + } + } else if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(err.Error()), + } + } + } + body.AvatarURL = profile.AvatarURL + body.DisplayName = profile.DisplayName + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: body, + } +} + // GetEventReports returns reported events for a given user/room. func GetEventReports( req *http.Request, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 945d0e48e..142b1f814 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -344,9 +344,16 @@ func Setup( })).Methods(http.MethodGet) synapseAdminRouter.Handle("/admin/v2/users/{userID}", - httputil.MakeServiceAdminAPI("admin_provision_user", m.AdminToken, func(r *http.Request) util.JSONResponse { - return AdminCreateOrModifyAccount(r, userAPI) - })).Methods(http.MethodPut) + httputil.MakeServiceAdminAPI("admin_manage_user", m.AdminToken, func(r *http.Request) util.JSONResponse { + switch r.Method { + case http.MethodGet: + return AdminRetrieveAccount(r, cfg, userAPI) + case http.MethodPut: + return AdminCreateOrModifyAccount(r, userAPI) + default: + return util.JSONResponse{Code: http.StatusMethodNotAllowed, JSON: nil} + } + })).Methods(http.MethodPut, http.MethodGet) synapseAdminRouter.Handle("/admin/v2/users/{userID}/devices", httputil.MakeServiceAdminAPI("admin_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { diff --git a/userapi/api/api.go b/userapi/api/api.go index bcd5c9c0b..5387276b7 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -31,7 +31,6 @@ type UserInternalAPI interface { FederationUserAPI QuerySearchProfilesAPI // used by p2p demos - QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, req *QueryLocalpartExternalIDRequest, res *QueryLocalpartExternalIDResponse) (err error) PerformLocalpartExternalUserIDCreation(ctx context.Context, req *PerformLocalpartExternalUserIDCreationRequest) (err error) } @@ -89,6 +88,7 @@ type ClientUserAPI interface { QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) @@ -461,6 +461,7 @@ type Account struct { ServerName spec.ServerName AppServiceID string AccountType AccountType + Deactivated bool // TODO: Associations (e.g. with application services) } @@ -660,7 +661,7 @@ type QueryAccountByLocalpartResponse struct { } type QueryLocalpartExternalIDRequest struct { - ExternalID string + ExternalID string AuthProvider string } @@ -669,8 +670,8 @@ type QueryLocalpartExternalIDResponse struct { } type PerformLocalpartExternalUserIDCreationRequest struct { - Localpart string - ExternalID string + Localpart string + ExternalID string AuthProvider string } @@ -914,7 +915,7 @@ type PerformUploadDeviceKeysResponse struct { } type PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest struct { - UserID string + UserID string Duration time.Duration } diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 489017fb9..5c0519962 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -55,7 +55,7 @@ const deactivateAccountSQL = "" + "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" + "SELECT localpart, server_name, appservice_id, account_type, is_deactivated FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE" @@ -116,6 +116,7 @@ func (s *accountsStatements) InsertAccount( localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { + // TODO: can we replace "UnixNano() / 1M" with "UnixMilli()"? createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) @@ -135,6 +136,7 @@ func (s *accountsStatements) InsertAccount( ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, + Deactivated: false, }, nil } @@ -167,7 +169,7 @@ func (s *accountsStatements) SelectAccountByLocalpart( var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType, &acc.Deactivated) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 66cc7c060..7c6279196 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -54,7 +54,7 @@ const deactivateAccountSQL = "" + "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" + "SELECT localpart, server_name, appservice_id, account_type, is_deactivated FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0" @@ -135,6 +135,7 @@ func (s *accountsStatements) InsertAccount( ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, + Deactivated: false, }, nil } @@ -167,7 +168,7 @@ func (s *accountsStatements) SelectAccountByLocalpart( var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType, &acc.Deactivated) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") From 524f65cb0c036c7128a7f47a12c494b85c3a07b0 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Mon, 30 Dec 2024 19:50:13 +0000 Subject: [PATCH 11/71] mas: add AccountTypeOIDCService --- setup/mscs/msc3861/msc3861_user_verifier.go | 5 ++--- userapi/api/api.go | 2 ++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index 597b844e2..c0f4342b9 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -127,7 +127,7 @@ func (m *MSC3861UserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Dev // Do not record requests from MAS using the virtual `__oidc_admin` user. if token != m.cfg.AdminToken { - // TODO: not sure which exact data we should record here. See the link for reference + // XXX: not sure which exact data we should record here. See the link for reference // https://github.com/element-hq/synapse/blob/develop/synapse/api/auth/base.py#L365 } @@ -156,7 +156,6 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st // XXX: This is a temporary solution so that the admin API can be called by // the OIDC provider. This will be removed once we have OIDC client // credentials grant support in matrix-authentication-service. - logger.Info("Admin token used") // XXX: that user doesn't exist and won't be provisioned. adminUser, err := createUserID("__oidc_admin", m.serverName) if err != nil { @@ -165,7 +164,7 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st return &requester{ UserID: adminUser, Scope: []string{"urn:synapse:admin:*"}, - Device: &api.Device{UserID: adminUser.Local(), AccountType: api.AccountTypeAdmin}, + Device: &api.Device{UserID: adminUser.Local(), AccountType: api.AccountTypeOIDCService}, }, nil } diff --git a/userapi/api/api.go b/userapi/api/api.go index 5387276b7..2efa8976e 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -529,6 +529,8 @@ const ( AccountTypeAdmin AccountType = 3 // AccountTypeAppService indicates this is an appservice account AccountTypeAppService AccountType = 4 + // AccountTypeOIDC indicates this is an account belonging to Matrix Authentication Service (MAS) + AccountTypeOIDCService AccountType = 5 ) type QueryPushersRequest struct { From ff63e7fa983386441709d84fc9df6ac33fbfaa59 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Mon, 30 Dec 2024 20:31:10 +0000 Subject: [PATCH 12/71] mas: modify PUT /profile/{userID}/displayname endpoint Extended logic of the endpoint in order to make it compatible with MAS --- clientapi/routing/admin.go | 48 +++++++++++++++++------------------- clientapi/routing/profile.go | 38 ++++++++++++++++++++-------- 2 files changed, 50 insertions(+), 36 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 0532a5775..c51ea7e98 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -760,42 +760,38 @@ func AdminRetrieveAccount(req *http.Request, cfg *config.ClientAPI, userAPI user Deactivated bool `json:"deactivated"` }{} - { - var rs api.QueryAccountByLocalpartResponse - err := userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{Localpart: local, ServerName: domain}, &rs) - if err == sql.ErrNoRows { + var rs api.QueryAccountByLocalpartResponse + err = userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{Localpart: local, ServerName: domain}, &rs) + if err == sql.ErrNoRows { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("User '%s' not found", userID)), + } + } else if err != nil { + logger.WithError(err).Error("userAPI.QueryAccountByLocalpart") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(err.Error()), + } + } + body.Deactivated = rs.Account.Deactivated + + profile, err := userAPI.QueryProfile(req.Context(), userID) + if err != nil { + if err == appserviceAPI.ErrProfileNotExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: spec.NotFound(fmt.Sprintf("User '%s' not found", userID)), + JSON: spec.NotFound(err.Error()), } } else if err != nil { - logger.WithError(err).Error("userAPI.QueryAccountByLocalpart") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.Unknown(err.Error()), } } - body.Deactivated = rs.Account.Deactivated - } - - { - profile, err := userAPI.QueryProfile(req.Context(), userID) - if err != nil { - if err == appserviceAPI.ErrProfileNotExists { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound(err.Error()), - } - } else if err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.Unknown(err.Error()), - } - } - } - body.AvatarURL = profile.AvatarURL - body.DisplayName = profile.DisplayName } + body.AvatarURL = profile.AvatarURL + body.DisplayName = profile.DisplayName return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index b75d38a62..74bbddbc7 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -172,24 +172,20 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( - req *http.Request, profileAPI userapi.ProfileAPI, + req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI, ) util.JSONResponse { - if userID != device.UserID { + if userID != device.UserID && device.AccountType != userapi.AccountTypeOIDCService { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("userID does not match the current user"), } } - var r eventutil.UserProfile - if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { - return *resErr - } - + logger := util.GetLogger(req.Context()) localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + logger.WithError(err).Error("gomatrixserverlib.SplitID failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, @@ -203,6 +199,28 @@ func SetDisplayName( } } + if device.AccountType == userapi.AccountTypeOIDCService { + // When a request is made on behalf of an OIDC provider service, the original device object refers + // to the provider's pseudo-device and includes only the AccountTypeOIDCService flag. To continue, + // we need to replace the admin's device with the user's device + var rs userapi.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &rs) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if len(rs.Devices) > 0 { + device = &rs.Devices[0] + } + } + + var r eventutil.UserProfile + if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { + return *resErr + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -211,9 +229,9 @@ func SetDisplayName( } } - profile, changed, err := profileAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) + profile, changed, err := userAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") + logger.WithError(err).Error("profileAPI.SetDisplayName failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, From bf310d558f267d107e541f0fa896f1bccf54b487 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Tue, 31 Dec 2024 01:40:14 +0000 Subject: [PATCH 13/71] drop primary key constraint from userapi_devices.access_token --- setup/mscs/msc3861/msc3861_user_verifier.go | 5 +- ...23101250000_drop_primary_key_constraint.go | 25 +++++++ userapi/storage/postgres/devices_table.go | 14 ++-- ...23101150000_drop_primary_key_constraint.go | 65 +++++++++++++++++++ userapi/storage/sqlite3/devices_table.go | 14 ++-- 5 files changed, 111 insertions(+), 12 deletions(-) create mode 100644 userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go create mode 100644 userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index c0f4342b9..cb73106f4 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -283,8 +283,6 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st Msg: strings.Join([]string{"Invalid device ID in scope: ", deviceID}, ""), } } - logger.Debugf("deviceID is: %s", deviceID) - logger.Debugf("scope is: %+v", scopes) userDeviceExists := false { @@ -302,14 +300,13 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st } } } - logger.Debugf("userDeviceExists is: %t", userDeviceExists) if !userDeviceExists { var rs api.PerformDeviceCreationResponse deviceDisplayName := "OIDC-native client" if err := m.userAPI.PerformDeviceCreation(ctx, &api.PerformDeviceCreationRequest{ Localpart: localpart, ServerName: m.serverName, - AccessToken: token, + AccessToken: "", DeviceID: &deviceID, DeviceDisplayName: &deviceDisplayName, // TODO: Cannot add IPAddr and Useragent values here. Should we care about it here? diff --git a/userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go b/userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go new file mode 100644 index 000000000..0bf7d3863 --- /dev/null +++ b/userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go @@ -0,0 +1,25 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` +ALTER TABLE userapi_devices DROP CONSTRAINT userapi_devices_pkey;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + ALTER TABLE userapi_devices ADD CONSTRAINT userapi_devices_pkey PRIMARY KEY (access_token);`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index b5feea07f..e76452447 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -116,10 +116,16 @@ func NewPostgresDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.Dev return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "userapi: add last_seen_ts", - Up: deltas.UpLastSeenTSIP, - }) + m.AddMigrations( + sqlutil.Migration{ + Version: "userapi: add last_seen_ts", + Up: deltas.UpLastSeenTSIP, + }, + sqlutil.Migration{ + Version: "userapi: drop primary key constraint", + Up: deltas.UpDropPrimaryKeyConstraint, + }, + ) err = m.Up(context.Background()) if err != nil { return nil, err diff --git a/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go b/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go new file mode 100644 index 000000000..def7a75e2 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go @@ -0,0 +1,65 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp; + CREATE TABLE userapi_devices ( + access_token TEXT, + session_id INTEGER, + device_id TEXT , + localpart TEXT , + server_name TEXT NOT NULL, + created_ts BIGINT, + display_name TEXT, + last_seen_ts BIGINT, + ip TEXT, + user_agent TEXT, + UNIQUE (localpart, device_id) + ); + INSERT + INTO userapi_devices ( + access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent + ) SELECT + access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', '' + FROM userapi_devices_tmp; + DROP TABLE userapi_devices_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` +ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp; +CREATE TABLE userapi_devices ( + access_token TEXT PRIMARY KEY, + session_id INTEGER, + device_id TEXT , + localpart TEXT , + server_name TEXT NOT NULL, + created_ts BIGINT, + display_name TEXT, + last_seen_ts BIGINT, + ip TEXT, + user_agent TEXT, + UNIQUE (localpart, device_id) + ); + INSERT + INTO userapi_devices ( + access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent + ) SELECT + access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', '' + FROM userapi_devices_tmp; + DROP TABLE userapi_devices_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index d5d1fed3d..2eb88109a 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -102,10 +102,16 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.Devic return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "userapi: add last_seen_ts", - Up: deltas.UpLastSeenTSIP, - }) + m.AddMigrations( + sqlutil.Migration{ + Version: "userapi: add last_seen_ts", + Up: deltas.UpLastSeenTSIP, + }, + sqlutil.Migration{ + Version: "userapi: drop primary key constraint", + Up: deltas.UpDropPrimaryKeyConstraint, + }, + ) if err = m.Up(context.Background()); err != nil { return nil, err } From f4ff4266f5cfc7d4ef99eb89884f6d068ae62aea Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 1 Jan 2025 00:35:06 +0000 Subject: [PATCH 14/71] mas: refactor admin user device handler --- clientapi/routing/admin.go | 67 ++++++++++++++++++++++++++---------- clientapi/routing/routing.go | 2 +- 2 files changed, 49 insertions(+), 20 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index c51ea7e98..588d779e8 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -523,27 +523,27 @@ func AdminCheckUsernameAvailable( } } -func AdminHandleUserDeviceByUserID( +func AdminHandleUserDeviceRetrievingCreation( req *http.Request, userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, ) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - userID, ok := vars["userID"] - if !ok { + userID, _ := vars["userID"] + local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) + if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.MissingParam("Expecting user ID."), + JSON: spec.InvalidParam(err.Error()), } } - logger := util.GetLogger(req.Context()) switch req.Method { case http.MethodPost: - local, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -557,21 +557,50 @@ func AdminHandleUserDeviceByUserID( return *resErr } - var rs userapi.PerformDeviceCreationResponse - if err := userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ - Localpart: local, - ServerName: domain, - DeviceID: &payload.DeviceID, - IPAddr: "", - UserAgent: req.UserAgent(), - NoDeviceListUpdate: false, - FromRegistration: false, - }, &rs); err != nil { - logger.WithError(err).Debug("PerformDeviceCreation failed") - return util.MessageResponse(http.StatusBadRequest, err.Error()) + userDeviceExists := false + { + var rs api.QueryDevicesResponse + if err := userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil { + logger.WithError(err).Error("QueryDevices") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if !rs.UserExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("Given user ID does not exist"), + } + } + for i := range rs.Devices { + logger.Errorf("%s :: %s", rs.Devices[0].ID, payload.DeviceID) + if d := rs.Devices[i]; d.ID == payload.DeviceID && d.UserID == userID { + userDeviceExists = true + break + } + } } - logger.WithError(err).Debug("PerformDeviceCreation succeeded") + if !userDeviceExists { + var rs userapi.PerformDeviceCreationResponse + if err := userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ + Localpart: local, + ServerName: domain, + DeviceID: &payload.DeviceID, + IPAddr: "", + UserAgent: req.UserAgent(), + NoDeviceListUpdate: false, + FromRegistration: false, + }, &rs); err != nil { + logger.WithError(err).Error("PerformDeviceCreation") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + logger.WithError(err).Debug("PerformDeviceCreation succeeded") + } return util.JSONResponse{ Code: http.StatusCreated, JSON: struct{}{}, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 142b1f814..182c8a47c 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -357,7 +357,7 @@ func Setup( synapseAdminRouter.Handle("/admin/v2/users/{userID}/devices", httputil.MakeServiceAdminAPI("admin_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { - return AdminHandleUserDeviceByUserID(r, userAPI) + return AdminHandleUserDeviceRetrievingCreation(r, userAPI, cfg) })).Methods(http.MethodPost, http.MethodGet) synapseAdminRouter.Handle("/admin/v1/users/{userID}/_allow_cross_signing_replacement_without_uia", From 803cce882f25dbdb4a21956ec7037994e921bf65 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 1 Jan 2025 02:26:15 +0000 Subject: [PATCH 15/71] mas: added admin's delete devices endpoint --- clientapi/routing/admin.go | 120 +++++++++++++++++++++++++++++++++-- clientapi/routing/routing.go | 12 +++- 2 files changed, 126 insertions(+), 6 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 588d779e8..b3272b0a6 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -33,9 +33,14 @@ import ( userapi "github.com/element-hq/dendrite/userapi/api" ) -const replacementPeriod = 10 * time.Minute +const ( + replacementPeriod time.Duration = 10 * time.Minute +) -var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") +var ( + validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") + deviceDisplayName = "OIDC-native client" +) func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { if !cfg.RegistrationRequiresToken { @@ -523,7 +528,7 @@ func AdminCheckUsernameAvailable( } } -func AdminHandleUserDeviceRetrievingCreation( +func AdminUserDeviceRetrieveCreate( req *http.Request, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, @@ -574,7 +579,6 @@ func AdminHandleUserDeviceRetrievingCreation( } } for i := range rs.Devices { - logger.Errorf("%s :: %s", rs.Devices[0].ID, payload.DeviceID) if d := rs.Devices[i]; d.ID == payload.DeviceID && d.UserID == userID { userDeviceExists = true break @@ -588,6 +592,7 @@ func AdminHandleUserDeviceRetrievingCreation( Localpart: local, ServerName: domain, DeviceID: &payload.DeviceID, + DeviceDisplayName: &deviceDisplayName, IPAddr: "", UserAgent: req.UserAgent(), NoDeviceListUpdate: false, @@ -638,7 +643,114 @@ func AdminHandleUserDeviceRetrievingCreation( JSON: struct{}{}, } } +} + +func AdminUserDeviceDelete( + req *http.Request, + userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, +) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID, _ := vars["userID"] + deviceID, _ := vars["deviceID"] + logger := util.GetLogger(req.Context()) + + // XXX: we probably have to delete session from the sessions dict + // like we do in DeleteDeviceById. If so, we have to fi + var device *api.Device + { + var rs api.QueryDevicesResponse + if err := userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil { + logger.WithError(err).Error("userAPI.QueryDevices failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if !rs.UserExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("Given user ID does not exist"), + } + } + for i := range rs.Devices { + if d := rs.Devices[i]; d.ID == deviceID && d.UserID == userID { + device = &d + break + } + } + } + + { + // XXX: this response struct can completely removed everywhere as it doesn't + // have any functional purpose + var res api.PerformDeviceDeletionResponse + if err := userAPI.PerformDeviceDeletion(req.Context(), &api.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: []string{device.ID}, + }, &res); err != nil { + logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func AdminUserDevicesDelete( + req *http.Request, + userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, +) util.JSONResponse { + logger := util.GetLogger(req.Context()) + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID, _ := vars["userID"] + + var payload struct { + Devices []string `json:"devices"` + } + + defer req.Body.Close() + if err = json.NewDecoder(req.Body).Decode(&payload); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("unable to decode device deletion request") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + { + // XXX: this response struct can completely removed everywhere as it doesn't + // have any functional purpose + var rs api.PerformDeviceDeletionResponse + if err := userAPI.PerformDeviceDeletion(req.Context(), &api.PerformDeviceDeletionRequest{ + UserID: userID, + DeviceIDs: payload.Devices, + }, &rs); err != nil { + logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } } func AdminAllowCrossSigningReplacementWithoutUIA( diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 182c8a47c..651a863dd 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -356,9 +356,17 @@ func Setup( })).Methods(http.MethodPut, http.MethodGet) synapseAdminRouter.Handle("/admin/v2/users/{userID}/devices", - httputil.MakeServiceAdminAPI("admin_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { - return AdminHandleUserDeviceRetrievingCreation(r, userAPI, cfg) + httputil.MakeServiceAdminAPI("admin_create_retrieve_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminUserDeviceRetrieveCreate(r, userAPI, cfg) })).Methods(http.MethodPost, http.MethodGet) + synapseAdminRouter.Handle("/admin/v2/users/{userID}/devices/{deviceID}", + httputil.MakeServiceAdminAPI("admin_delete_user_device", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminUserDeviceDelete(r, userAPI, cfg) + })).Methods(http.MethodDelete) + synapseAdminRouter.Handle("/admin/v2/users/{userID}/delete_devices", + httputil.MakeServiceAdminAPI("admin_delete_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminUserDevicesDelete(r, userAPI, cfg) + })).Methods(http.MethodPost) synapseAdminRouter.Handle("/admin/v1/users/{userID}/_allow_cross_signing_replacement_without_uia", httputil.MakeServiceAdminAPI("admin_allow_cross_signing_replacement_without_uia", m.AdminToken, func(r *http.Request) util.JSONResponse { From 7ffb2c1d81ad195093b1b6b7cd611264cbd408b3 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sun, 5 Jan 2025 02:22:56 +0000 Subject: [PATCH 16/71] mas: minor fixes in cross_signing_keys_table files --- userapi/storage/postgres/cross_signing_keys_table.go | 6 +++--- userapi/storage/sqlite3/cross_signing_keys_table.go | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index 7e66a0114..66ccc9a33 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -45,8 +45,8 @@ const upsertCrossSigningKeysForUserSQL = "" + const updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL = "" + "UPDATE keyserver_cross_signing_keys" + - " SET updatable_without_uia_before_ms = $3" + - " WHERE user_id = $1 AND key_type = $2" + " SET updatable_without_uia_before_ms = $1" + + " WHERE user_id = $2 AND key_type = $3" type crossSigningKeysStatements struct { db *sql.DB @@ -150,7 +150,7 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( func (s *crossSigningKeysStatements) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, txn *sql.Tx, userID string, duration time.Duration) (int64, error) { keyTypeInt, _ := types.KeyTypePurposeToInt[fclient.CrossSigningKeyPurposeMaster] ts := time.Now().Add(duration).UnixMilli() - result, err := sqlutil.TxStmt(txn, s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt).ExecContext(ctx, userID, keyTypeInt, ts) + result, err := sqlutil.TxStmt(txn, s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt).ExecContext(ctx, ts, userID, keyTypeInt) if err != nil { return -1, err } diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index 47a39f6b1..c86cf2063 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -44,8 +44,8 @@ const upsertCrossSigningKeysForUserSQL = "" + const updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL = "" + "UPDATE keyserver_cross_signing_keys" + - " SET updatable_without_uia_before_ms = $3" + - " WHERE user_id = $1 AND key_type = $2" + " SET updatable_without_uia_before_ms = $1" + + " WHERE user_id = $2 AND key_type = $3" type crossSigningKeysStatements struct { db *sql.DB @@ -149,7 +149,7 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( func (s *crossSigningKeysStatements) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, txn *sql.Tx, userID string, duration time.Duration) (int64, error) { keyTypeInt, _ := types.KeyTypePurposeToInt[fclient.CrossSigningKeyPurposeMaster] ts := time.Now().Add(duration).UnixMilli() - result, err := sqlutil.TxStmt(txn, s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt).ExecContext(ctx, userID, keyTypeInt, ts) + result, err := sqlutil.TxStmt(txn, s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt).ExecContext(ctx, ts, userID, keyTypeInt) if err != nil { return -1, err } From c06e0aa2065d580d2f78caf89203282a5a6bf23b Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sun, 5 Jan 2025 02:24:40 +0000 Subject: [PATCH 17/71] refactor logger calls --- clientapi/routing/admin.go | 4 ++-- setup/mscs/msc3861/msc3861_user_verifier.go | 1 - userapi/internal/cross_signing.go | 5 +++-- userapi/internal/key_api.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index b3272b0a6..d9ef58827 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -724,7 +724,7 @@ func AdminUserDevicesDelete( defer req.Body.Close() if err = json.NewDecoder(req.Body).Decode(&payload); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("unable to decode device deletion request") + logger.WithError(err).Error("unable to decode device deletion request") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, @@ -860,7 +860,7 @@ func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI DisplayName: r.DisplayName, }, &res) if err != nil { - util.GetLogger(req.Context()).WithError(err).Debugln("Failed creating account") + logger.WithError(err).Debugln("Failed creating account") return util.MessageResponse(http.StatusBadRequest, err.Error()) } if res.AccountCreated { diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index cb73106f4..59a4c3aa1 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -173,7 +173,6 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st logger.WithError(err).Error("MSC3861UserVerifier:introspectToken") return nil, err } - logger.Debugf("Introspection result: %+v", *introspectionResult) if !introspectionResult.Active { return nil, &mscError{Code: codeInvalidClientToken, Msg: "Token is not active"} diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go index a93fba150..b638f59c0 100644 --- a/userapi/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -471,10 +471,11 @@ func (a *UserInternalAPI) processOtherSignatures( func (a *UserInternalAPI) crossSigningKeysFromDatabase( ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, ) { + logger := logrus.WithContext(ctx) for targetUserID := range req.UserToDevices { keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil { - logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) + logger.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) continue } @@ -487,7 +488,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase( sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) if err != nil && err != sql.ErrNoRows { - logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) + logger.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) continue } diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 98c387842..7ffaf10d7 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -285,7 +285,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque DeviceIDs: dids, }, &queryRes) if err != nil { - util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") + util.GetLogger(ctx).WithError(err).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") } if res.DeviceKeys[userID] == nil { From 48f3cd3367f3684566c5ad2596c1cb860a32bd37 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sun, 5 Jan 2025 02:35:37 +0000 Subject: [PATCH 18/71] mas: added /admin/v1/deactivate/{userID} endpoint --- clientapi/routing/admin.go | 36 ++++++++++++++++++++++++++++++++++++ clientapi/routing/routing.go | 5 ++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index d9ef58827..be865c39f 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -753,6 +753,42 @@ func AdminUserDevicesDelete( } } +func AdminDeactivateAccount( + req *http.Request, + userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, +) util.JSONResponse { + logger := util.GetLogger(req.Context()) + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + userID, _ := vars["userID"] + local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) + if err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + + // TODO: "erase" field must also be processed here + // see https://github.com/element-hq/synapse/blob/develop/docs/admin_api/user_admin_api.md#deactivate-account + + var rs api.PerformAccountDeactivationResponse + if err := userAPI.PerformAccountDeactivation(req.Context(), &api.PerformAccountDeactivationRequest{ + Localpart: local, ServerName: domain, + }, &rs); err != nil { + logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + func AdminAllowCrossSigningReplacementWithoutUIA( req *http.Request, userAPI userapi.ClientUserAPI, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 651a863dd..0e672f506 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -342,7 +342,10 @@ func Setup( httputil.MakeServiceAdminAPI("admin_username_available", m.AdminToken, func(r *http.Request) util.JSONResponse { return AdminCheckUsernameAvailable(r, userAPI, cfg) })).Methods(http.MethodGet) - + synapseAdminRouter.Handle("/admin/v1/deactivate/{userID}", + httputil.MakeServiceAdminAPI("admin_deactivate_user", m.AdminToken, func(r *http.Request) util.JSONResponse { + return AdminDeactivateAccount(r, userAPI, cfg) + })).Methods(http.MethodPost) synapseAdminRouter.Handle("/admin/v2/users/{userID}", httputil.MakeServiceAdminAPI("admin_manage_user", m.AdminToken, func(r *http.Request) util.JSONResponse { switch r.Method { From 9b064b1572b654699c76d5aa59fc69e6900753d3 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sun, 5 Jan 2025 02:38:07 +0000 Subject: [PATCH 19/71] minor refactoring --- clientapi/routing/admin.go | 19 ++++--------------- clientapi/routing/routing.go | 2 +- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index be865c39f..fcb167e22 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -814,12 +814,7 @@ func AdminAllowCrossSigningReplacementWithoutUIA( } var rs userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse err = userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA(req.Context(), &rq, &rs) - if err == sql.ErrNoRows { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.MissingParam("User has no master cross-signing key"), - } - } else if err != nil { + if err != nil && err != sql.ErrNoRows { util.GetLogger(req.Context()).WithError(err).Error("userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA") return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -858,20 +853,14 @@ type adminCreateOrModifyAccountRequest struct { // Locked bool `json:"locked"` } -func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI) util.JSONResponse { +func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI) util.JSONResponse { logger := util.GetLogger(req.Context()) vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - userID, ok := vars["userID"] - if !ok { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.MissingParam("Expecting user ID."), - } - } - local, domain, err := gomatrixserverlib.SplitID('@', userID) + userID, _ := vars["userID"] + local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 0e672f506..55a660fe5 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -352,7 +352,7 @@ func Setup( case http.MethodGet: return AdminRetrieveAccount(r, cfg, userAPI) case http.MethodPut: - return AdminCreateOrModifyAccount(r, userAPI) + return AdminCreateOrModifyAccount(r, userAPI, cfg) default: return util.JSONResponse{Code: http.StatusMethodNotAllowed, JSON: nil} } From cc7deb22ad04fec064a13de84b2ca4c932707fb8 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sun, 5 Jan 2025 02:42:20 +0000 Subject: [PATCH 20/71] mas: added support of msc3861 to /keys/device_signing/upload endpoint this change is based mostly on changes made in synapse https://github.com/element-hq/synapse/blob/develop/synapse/rest/client/keys.py#L392 --- clientapi/routing/key_crosssigning.go | 39 ++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 78b66400b..d3950e75c 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -8,6 +8,7 @@ package routing import ( "net/http" + "slices" "strings" "time" @@ -16,6 +17,7 @@ import ( "github.com/element-hq/dendrite/clientapi/httputil" "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -58,6 +60,41 @@ func UploadCrossSigningDeviceKeys( } } + { + var keysResp api.QueryKeysResponse + keyserverAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{UserID: device.UserID, UserToDevices: map[string][]string{device.UserID: []string{}}}, &keysResp) + if err := keysResp.Error; err != nil { + return convertKeyError(err) + } + hasDifferentKeys := func(userID string, uploadReqCSKey *fclient.CrossSigningKey, dbCSKeys map[string]fclient.CrossSigningKey) bool { + dbCSKey, ok := dbCSKeys[userID] + if !ok { + return true + } + dbKeysExist := len(dbCSKey.Keys) > 0 + for keyID, key := range uploadReqCSKey.Keys { + // If dbKeysExist is false and we enter the loop, it means we have received at least one key that is not in the DB, and we want to persist it. + if !dbKeysExist { + return true + } + dbKey, ok := dbCSKey.Keys[keyID] + if !ok || !slices.Equal(dbKey, key) { + return true + } + } + return false + } + + if !hasDifferentKeys(device.UserID, &uploadReq.MasterKey, keysResp.MasterKeys) && + !hasDifferentKeys(device.UserID, &uploadReq.SelfSigningKey, keysResp.SelfSigningKeys) && + !hasDifferentKeys(device.UserID, &uploadReq.UserSigningKey, keysResp.UserSigningKeys) { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: map[int]interface{}{}, + } + } + } + if isCrossSigningSetup { // With MSC3861, UIA is not possible. Instead, the auth service has to explicitly mark the master key as replaceable. if cfg.MSCs.MSC3861Enabled() { @@ -83,7 +120,7 @@ func UploadCrossSigningDeviceKeys( }, }, strings.Join([]string{ - "To reset your end-to-end encryption cross-signing, identity, you first need to approve it at", + "To reset your end-to-end encryption cross-signing identity, you first need to approve it at", url, "and then try again.", }, " "), From 5cffc2c25732f913d4e9d962548d1ac9a083235a Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Mon, 6 Jan 2025 03:18:31 +0000 Subject: [PATCH 21/71] mas: fix displayname handling --- clientapi/routing/admin.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index fcb167e22..b213d5aaf 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -840,7 +840,7 @@ type adminExternalID struct { } type adminCreateOrModifyAccountRequest struct { - DisplayName string `json:"display_name"` + DisplayName string `json:"displayname"` AvatarURL string `json:"avatar_url"` // TODO: the following fields are not used here, but they are used in Synapse. Probably we should reproduce the logic of the // endpoint fully compatible. From 811a504e013233e49ece7bc78d6b3ef93a221aeb Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Mon, 6 Jan 2025 03:19:47 +0000 Subject: [PATCH 22/71] mas: handle 3pids from mas --- clientapi/routing/admin.go | 63 +++++++++++++++++++++---------- userapi/api/api.go | 7 ++++ userapi/internal/user_api.go | 4 ++ userapi/storage/interface.go | 1 + userapi/storage/shared/storage.go | 35 +++++++++++++++++ 5 files changed, 91 insertions(+), 19 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index b213d5aaf..681121546 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -23,6 +23,7 @@ import ( appserviceAPI "github.com/element-hq/dendrite/appservice/api" clientapi "github.com/element-hq/dendrite/clientapi/api" + "github.com/element-hq/dendrite/clientapi/auth/authtypes" clienthttputil "github.com/element-hq/dendrite/clientapi/httputil" "github.com/element-hq/dendrite/clientapi/userutil" "github.com/element-hq/dendrite/internal/httputil" @@ -31,6 +32,7 @@ import ( "github.com/element-hq/dendrite/setup/jetstream" "github.com/element-hq/dendrite/userapi/api" userapi "github.com/element-hq/dendrite/userapi/api" + "github.com/element-hq/dendrite/userapi/storage/shared" ) const ( @@ -842,11 +844,13 @@ type adminExternalID struct { type adminCreateOrModifyAccountRequest struct { DisplayName string `json:"displayname"` AvatarURL string `json:"avatar_url"` - // TODO: the following fields are not used here, but they are used in Synapse. Probably we should reproduce the logic of the - // endpoint fully compatible. + ThreePIDs []struct { + Medium string `json:"medium"` + Address string `json:"address"` + } `json:"threepids"` + // TODO: the following fields are not used here, but they are used in Synapse. // Password string `json:"password"` // LogoutDevices bool `json:"logout_devices"` - // Threepids json.RawMessage `json:"threepids"` // ExternalIDs []adminExternalID `json:"external_ids"` // Admin bool `json:"admin"` // Deactivated bool `json:"deactivated"` @@ -872,24 +876,45 @@ func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI logger.Debugf("UnmarshalJSONRequest failed: %+v", *resErr) return *resErr } - logger.Debugf("adminCreateOrModifyAccountRequest is: %+v", r) + logger.Debugf("adminCreateOrModifyAccountRequest is: %#v", r) statusCode := http.StatusOK - { - var res userapi.PerformAccountCreationResponse - err = userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ - AccountType: userapi.AccountTypeUser, - Localpart: local, - ServerName: domain, - OnConflict: api.ConflictUpdate, - AvatarURL: r.AvatarURL, - DisplayName: r.DisplayName, - }, &res) - if err != nil { - logger.WithError(err).Debugln("Failed creating account") - return util.MessageResponse(http.StatusBadRequest, err.Error()) + + // TODO: Ideally, the following commands should be executed in one transaction. + // can we propagate the tx object and pass it in context? + var res userapi.PerformAccountCreationResponse + err = userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeUser, + Localpart: local, + ServerName: domain, + OnConflict: api.ConflictUpdate, + AvatarURL: r.AvatarURL, + DisplayName: r.DisplayName, + }, &res) + if err != nil { + logger.WithError(err).Error("userAPI.PerformAccountCreation") + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if res.AccountCreated { + statusCode = http.StatusCreated + } + + if l := len(r.ThreePIDs); l > 0 { + logger.Debugf("Trying to bulk save 3PID associations: %+v", r.ThreePIDs) + threePIDs := make([]authtypes.ThreePID, 0, len(r.ThreePIDs)) + for i := range r.ThreePIDs { + tpid := &r.ThreePIDs[i] + threePIDs = append(threePIDs, authtypes.ThreePID{Medium: tpid.Medium, Address: tpid.Address}) } - if res.AccountCreated { - statusCode = http.StatusCreated + err = userAPI.PerformBulkSaveThreePIDAssociation(req.Context(), &userapi.PerformBulkSaveThreePIDAssociationRequest{ + ThreePIDs: threePIDs, + Localpart: local, + ServerName: domain, + }, &struct{}{}) + if err == shared.Err3PIDInUse { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } else if err != nil { + logger.WithError(err).Error("userAPI.PerformSaveThreePIDAssociation") + return util.ErrorResponse(err) } } diff --git a/userapi/api/api.go b/userapi/api/api.go index 2efa8976e..08308bc3b 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -111,6 +111,7 @@ type ClientUserAPI interface { QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error + PerformBulkSaveThreePIDAssociation(ctx context.Context, req *PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error } type KeyBackupAPI interface { @@ -653,6 +654,12 @@ type PerformSaveThreePIDAssociationRequest struct { Medium string } +type PerformBulkSaveThreePIDAssociationRequest struct { + ThreePIDs []authtypes.ThreePID + Localpart string + ServerName spec.ServerName +} + type QueryAccountByLocalpartRequest struct { Localpart string ServerName spec.ServerName diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index a7760c1b2..b691496ec 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -987,4 +987,8 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium) } +func (a *UserInternalAPI) PerformBulkSaveThreePIDAssociation(ctx context.Context, req *api.PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error { + return a.DB.BulkSaveThreePIDAssociation(ctx, req.ThreePIDs, req.Localpart, req.ServerName) +} + const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 11b360952..2c0b4bf2a 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -120,6 +120,7 @@ type Pusher interface { type ThreePID interface { SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName spec.ServerName, medium string) (err error) + BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName spec.ServerName, err error) GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error) diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index e76cde61b..3d6d51edc 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -353,6 +353,41 @@ func (d *Database) SaveThreePIDAssociation( }) } +// BulkSaveThreePIDAssociation recreates 3PIDs for a user. +// If the third-party identifier is already part of an association, returns Err3PIDInUse. +// Returns an error if there was a problem talking to the database. +func (d *Database) BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error) { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + oldThreePIDs, err := d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName) + if err != nil { + return err + } + for _, t := range oldThreePIDs { + if err := d.ThreePIDs.DeleteThreePID(ctx, txn, t.Address, t.Medium); err != nil { + return err + } + } + for _, t := range threePIDs { + // if 3PID is associated with another user, return Err3PIDInUse + user, _, err := d.ThreePIDs.SelectLocalpartForThreePID( + ctx, txn, t.Address, t.Medium, + ) + if err != nil { + return err + } + + if len(user) > 0 && user != localpart { + return Err3PIDInUse + } + + if err = d.ThreePIDs.InsertThreePID(ctx, txn, t.Address, t.Medium, localpart, serverName); err != nil { + return err + } + } + return nil + }) +} + // RemoveThreePIDAssociation removes the association involving a given third-party // identifier. // If no association exists involving this third-party identifier, returns nothing. From 17576cc6d23d23ea9ea89a41b2c6eee9ffadfa93 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Tue, 7 Jan 2025 01:18:39 +0000 Subject: [PATCH 23/71] mas: acced msc3861 config example to the dendrite-sample.yaml --- dendrite-sample.yaml | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/dendrite-sample.yaml b/dendrite-sample.yaml index 0ee381f02..967d5cfb4 100644 --- a/dendrite-sample.yaml +++ b/dendrite-sample.yaml @@ -285,8 +285,29 @@ media_api: # Configuration for enabling experimental MSCs on this homeserver. mscs: mscs: + # - msc3861 # (Next-gen auth, see https://github.com/matrix-org/matrix-doc/pull/3861. MUST always go first in the list) # - msc2836 # (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) + # This block has no effect if the feature is not activated in the list above + # msc3861: + # enabled: true + + # # OIDC issuer advertised by the service. + # # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#http + # issuer: "https://mas.example.com/" + + # # Credentials used for authenticating requests coming from dendrite to auth service. + # # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#clients + # client_id: 01JFNM9MCHKV6A7A0C0RBHMYC0 + # client_secret: c85731184ac8f9aea76cf48146046b454473ca667a0cd1fd52a43034a0662eed + + # # The service token used for authenticating requests coming from auth service to dendrite. + # # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#matrix + # admin_token: ttJORW9oV4Wf4DJ63GdZEYekE2KElP4g + + # # URL of the account page on the auth service side + # account_management_url: "https://mas.example.com/account" + # Configuration for the Sync API. sync_api: # This option controls which HTTP header to inspect to find the real remote IP From e943ba5c8d9ec437a72d568b91fe44ba550868bd Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Tue, 7 Jan 2025 01:20:13 +0000 Subject: [PATCH 24/71] mas: fail if conflicts in config occur Since MSC3861 is conflicting with standard reg/login flows, we require to disable them before running the server --- setup/config/config_clientapi.go | 61 +++++++++++++++++++------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 85dfe0beb..4f6a7ac20 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -74,34 +74,45 @@ func (c *ClientAPI) Defaults(opts DefaultOpts) { func (c *ClientAPI) Verify(configErrs *ConfigErrors) { c.TURN.Verify(configErrs) c.RateLimiting.Verify(configErrs) - if c.RecaptchaEnabled { - if c.RecaptchaSiteVerifyAPI == "" { - c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify" - } - if c.RecaptchaApiJsUrl == "" { - c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js" + + if c.MSCs.MSC3861Enabled() { + if c.RecaptchaEnabled || !c.RegistrationDisabled { + configErrs.Add( + "You have enabled the experimental feature MSC3861 which implements the delegated authentication via OIDC." + + "As a result, the feature conflicts with the standard Dendrite's registration and login flows and cannot be used if any of those is enabled." + + "You need to disable registration (client_api.registration_disabled) and recapthca (client_api.enable_registration_captcha) options to proceed.", + ) } - if c.RecaptchaFormField == "" { - c.RecaptchaFormField = "g-recaptcha-response" + } else { + if c.RecaptchaEnabled { + if c.RecaptchaSiteVerifyAPI == "" { + c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify" + } + if c.RecaptchaApiJsUrl == "" { + c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js" + } + if c.RecaptchaFormField == "" { + c.RecaptchaFormField = "g-recaptcha-response" + } + if c.RecaptchaSitekeyClass == "" { + c.RecaptchaSitekeyClass = "g-recaptcha" + } + checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) + checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) + checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) + checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass) } - if c.RecaptchaSitekeyClass == "" { - c.RecaptchaSitekeyClass = "g-recaptcha" + // Ensure there is any spam counter measure when enabling registration + if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled && !c.RecaptchaEnabled { + configErrs.Add( + "You have tried to enable open registration without any secondary verification methods " + + "(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " + + "increasing the risk that your server will be used to send spam or abuse, and may result in " + + "your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " + + "start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " + + "should set the registration_disabled option in your Dendrite config.", + ) } - checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) - checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) - checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) - checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass) - } - // Ensure there is any spam counter measure when enabling registration - if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled && !c.RecaptchaEnabled { - configErrs.Add( - "You have tried to enable open registration without any secondary verification methods " + - "(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " + - "increasing the risk that your server will be used to send spam or abuse, and may result in " + - "your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " + - "start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " + - "should set the registration_disabled option in your Dendrite config.", - ) } } From 7eec60e2effa1977514cfb2ce4099fa16e60e978 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Tue, 7 Jan 2025 02:28:20 +0000 Subject: [PATCH 25/71] mas: reorganise endpoints --- clientapi/routing/routing.go | 98 ++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 48 deletions(-) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 55a660fe5..bea7332f5 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -330,14 +330,15 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) - if m := mscCfg.MSC3861; mscCfg.Enabled("msc3861") && m != nil && m.Enabled { + if mscCfg.MSC3861Enabled() { + m := mscCfg.MSC3861 + unstableMux.Handle("/org.matrix.msc2965/auth_issuer", httputil.MakeExternalAPI("auth_issuer", func(r *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]string{ "issuer": m.Issuer, }} })).Methods(http.MethodGet) - synapseAdminRouter.Handle("/admin/v1/username_available", httputil.MakeServiceAdminAPI("admin_username_available", m.AdminToken, func(r *http.Request) util.JSONResponse { return AdminCheckUsernameAvailable(r, userAPI, cfg) @@ -357,7 +358,6 @@ func Setup( return util.JSONResponse{Code: http.StatusMethodNotAllowed, JSON: nil} } })).Methods(http.MethodPut, http.MethodGet) - synapseAdminRouter.Handle("/admin/v2/users/{userID}/devices", httputil.MakeServiceAdminAPI("admin_create_retrieve_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { return AdminUserDeviceRetrieveCreate(r, userAPI, cfg) @@ -370,11 +370,57 @@ func Setup( httputil.MakeServiceAdminAPI("admin_delete_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { return AdminUserDevicesDelete(r, userAPI, cfg) })).Methods(http.MethodPost) - synapseAdminRouter.Handle("/admin/v1/users/{userID}/_allow_cross_signing_replacement_without_uia", httputil.MakeServiceAdminAPI("admin_allow_cross_signing_replacement_without_uia", m.AdminToken, func(r *http.Request) util.JSONResponse { return AdminAllowCrossSigningReplacementWithoutUIA(r, userAPI) })).Methods(http.MethodPost) + } else { + // If msc3861 is enabled, these endpoints are either redundant or replaced by Matrix Auth Service (MAS) + // Once we migrate to MAS completely, these edndpoints should be removed + + v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + if r := rateLimits.Limit(req, nil); r != nil { + return *r + } + return Register(req, userAPI, cfg) + })).Methods(http.MethodPost, http.MethodOptions) + + v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + if r := rateLimits.Limit(req, nil); r != nil { + return *r + } + return RegisterAvailable(req, cfg, userAPI) + })).Methods(http.MethodGet, http.MethodOptions) + + // Stub endpoints required by Element + + v3mux.Handle("/login", + httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + if r := rateLimits.Limit(req, nil); r != nil { + return *r + } + return Login(req, userAPI, cfg) + }), + ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) + + v3mux.Handle("/auth/{authType}/fallback/web", + httputil.MakeHTTPAPI("auth_fallback", userVerifier, enableMetrics, func(w http.ResponseWriter, req *http.Request) { + vars := mux.Vars(req) + AuthFallback(w, req, vars["authType"], cfg) + }), + ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) + + v3mux.Handle("/logout", + httputil.MakeAuthAPI("logout", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return Logout(req, userAPI, device) + }), + ).Methods(http.MethodPost, http.MethodOptions) + + v3mux.Handle("/logout/all", + httputil.MakeAuthAPI("logout", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return LogoutAll(req, userAPI, device) + }), + ).Methods(http.MethodPost, http.MethodOptions) } if mscCfg.Enabled("msc2753") { @@ -577,20 +623,6 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { - if r := rateLimits.Limit(req, nil); r != nil { - return *r - } - return Register(req, userAPI, cfg) - })).Methods(http.MethodPost, http.MethodOptions) - - v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { - if r := rateLimits.Limit(req, nil); r != nil { - return *r - } - return RegisterAvailable(req, cfg, userAPI) - })).Methods(http.MethodGet, http.MethodOptions) - v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -666,18 +698,6 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - v3mux.Handle("/logout", - httputil.MakeAuthAPI("logout", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return Logout(req, userAPI, device) - }), - ).Methods(http.MethodPost, http.MethodOptions) - - v3mux.Handle("/logout/all", - httputil.MakeAuthAPI("logout", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return LogoutAll(req, userAPI, device) - }), - ).Methods(http.MethodPost, http.MethodOptions) - v3mux.Handle("/rooms/{roomID}/typing/{userID}", httputil.MakeAuthAPI("rooms_typing", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { @@ -762,24 +782,6 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - // Stub endpoints required by Element - - v3mux.Handle("/login", - httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { - if r := rateLimits.Limit(req, nil); r != nil { - return *r - } - return Login(req, userAPI, cfg) - }), - ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - - v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTTPAPI("auth_fallback", userVerifier, enableMetrics, func(w http.ResponseWriter, req *http.Request) { - vars := mux.Vars(req) - AuthFallback(w, req, vars["authType"], cfg) - }), - ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - // Push rules v3mux.Handle("/pushrules", From fb15db7f7a3389b26115f5263034f848f1f40d4f Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 8 Jan 2025 10:11:27 +0000 Subject: [PATCH 26/71] unit tests fix --- appservice/appservice_test.go | 6 ++-- clientapi/admin_test.go | 4 +-- clientapi/clientapi_test.go | 49 ++++++++++++++++++++---------- clientapi/routing/login_test.go | 4 ++- roomserver/roomserver_test.go | 4 ++- setup/mscs/msc2836/msc2836_test.go | 4 ++- 6 files changed, 48 insertions(+), 23 deletions(-) diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index b7cd88562..976a62725 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/element-hq/dendrite/clientapi" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/auth/authtypes" "github.com/element-hq/dendrite/federationapi/statistics" "github.com/element-hq/dendrite/internal/httputil" @@ -446,7 +447,8 @@ func TestOutputAppserviceEvent(t *testing.T) { } usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - clientapi.AddPublicRoutes(processCtx, routers, cfg, natsInstance, nil, rsAPI, nil, nil, nil, usrAPI, nil, nil, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: usrAPI} + clientapi.AddPublicRoutes(processCtx, routers, cfg, natsInstance, nil, rsAPI, nil, nil, nil, usrAPI, nil, nil, &userVerifier, caching.DisableMetrics) createAccessTokens(t, accessTokens, usrAPI, processCtx.Context(), routers) room := test.NewRoom(t, alice) @@ -537,7 +539,7 @@ func TestOutputAppserviceEvent(t *testing.T) { } // Start the syncAPI to have `/joined_members` available - syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, usrAPI, rsAPI, caches, caching.DisableMetrics) + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, usrAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics) // start the consumer appservice.NewInternalAPI(processCtx, cfg, natsInstance, usrAPI, rsAPI) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 179e91407..5746aec94 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -801,14 +801,14 @@ func TestPurgeRoom(t *testing.T) { rsAPI.SetFederationAPI(fsAPI, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics) // Create the room if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go index ad2d4ad48..b844699da 100644 --- a/clientapi/clientapi_test.go +++ b/clientapi/clientapi_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/element-hq/dendrite/appservice" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/auth/authtypes" "github.com/element-hq/dendrite/clientapi/routing" "github.com/element-hq/dendrite/clientapi/threepid" @@ -127,9 +128,10 @@ func TestGetPutDevices(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -176,9 +178,10 @@ func TestDeleteDevice(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -281,9 +284,10 @@ func TestDeleteDevices(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -449,8 +453,9 @@ func TestSetDisplayname(t *testing.T) { rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} - AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -561,8 +566,9 @@ func TestSetAvatarURL(t *testing.T) { rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} - AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -638,8 +644,9 @@ func TestTyping(t *testing.T) { rsAPI.SetFederationAPI(nil, nil) // Needed to create accounts userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -723,8 +730,9 @@ func TestMembership(t *testing.T) { // Needed to create accounts userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) rsAPI.SetUserAPI(userAPI) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -962,8 +970,9 @@ func TestCapabilities(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -1010,9 +1019,10 @@ func TestTurnserver(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} //rsAPI.SetUserAPI(userAPI) // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -1109,8 +1119,9 @@ func Test3PID(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) // Create the users in the userapi and login accessTokens := map[*test.User]userDevice{ @@ -1285,9 +1296,10 @@ func TestPushRules(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -1672,9 +1684,10 @@ func TestKeys(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -2134,9 +2147,10 @@ func TestKeyBackup(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -2238,9 +2252,10 @@ func TestGetMembership(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -2301,9 +2316,10 @@ func TestCreateRoomInvite(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, @@ -2376,9 +2392,10 @@ func TestReportEvent(t *testing.T) { if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ alice: {}, diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go index b987e6f23..f88dc6690 100644 --- a/clientapi/routing/login_test.go +++ b/clientapi/routing/login_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/auth/authtypes" "github.com/element-hq/dendrite/internal/caching" "github.com/element-hq/dendrite/internal/httputil" @@ -50,9 +51,10 @@ func TestLogin(t *testing.T) { rsAPI.SetFederationAPI(nil, nil) // Needed for /login userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - Setup(routers, cfg, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, nil, caching.DisableMetrics) + Setup(routers, cfg, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, nil, &userVerifier, caching.DisableMetrics) // Create password password := util.RandomString(8) diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 48911d2bb..01d4b47dd 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/federationapi/statistics" "github.com/element-hq/dendrite/internal/caching" "github.com/element-hq/dendrite/internal/eventutil" @@ -267,7 +268,8 @@ func TestPurgeRoom(t *testing.T) { rsAPI.SetFederationAPI(fsAPI, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, fsAPI.IsBlacklistedOrBackingOff) - syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics) // Create the room if err = api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 5b85e6707..024f175c7 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/setup/process" "github.com/element-hq/dendrite/syncapi/synctypes" "github.com/gorilla/mux" @@ -571,7 +572,8 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve processCtx := process.NewProcessContext() cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) routers := httputil.NewRouters() - err := msc2836.Enable(cfg, cm, routers, rsAPI, nil, userAPI, nil) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + err := msc2836.Enable(cfg, cm, routers, rsAPI, nil, &userVerifier, nil) if err != nil { t.Fatalf("failed to enable MSC2836: %s", err) } From b44a79c15957447245a2b2047838fd6b7f585f52 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 8 Jan 2025 11:31:01 +0000 Subject: [PATCH 27/71] Bump golang version --- Dockerfile | 2 +- build/scripts/Complement.Dockerfile | 2 +- build/scripts/ComplementLocal.Dockerfile | 2 +- build/scripts/ComplementPostgres.Dockerfile | 2 +- cmd/dendrite-upgrade-tests/main.go | 4 ++-- go.mod | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index 27e7b39ad..6c2e0cfa6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ # # base installs required dependencies and runs go mod download to cache dependencies # -FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22-alpine AS base +FROM --platform=${BUILDPLATFORM} docker.io/golang:1.23-alpine AS base RUN apk --update --no-cache add bash build-base curl git # diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 660b84a46..e23aad8bd 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -1,6 +1,6 @@ #syntax=docker/dockerfile:1.2 -FROM golang:1.22-bookworm as build +FROM golang:1.23-bookworm as build RUN apt-get update && apt-get install -y sqlite3 WORKDIR /build diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile index 8fc847650..c2af16495 100644 --- a/build/scripts/ComplementLocal.Dockerfile +++ b/build/scripts/ComplementLocal.Dockerfile @@ -8,7 +8,7 @@ # # Use these mounts to make use of this dockerfile: # COMPLEMENT_HOST_MOUNTS='/your/local/dendrite:/dendrite:ro;/your/go/path:/go:ro' -FROM golang:1.22-bookworm +FROM golang:1.23-bookworm RUN apt-get update && apt-get install -y sqlite3 ENV SERVER_NAME=localhost diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index 0026842d8..48843eb08 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -1,6 +1,6 @@ #syntax=docker/dockerfile:1.2 -FROM golang:1.22-bookworm as build +FROM golang:1.23-bookworm as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index e1ac179a7..519d5e47d 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -55,7 +55,7 @@ var latest, _ = semver.NewVersion("v6.6.6") // Dummy version, used as "HEAD" // due to the error: // When using COPY with more than one source file, the destination must be a directory and end with a / // We need to run a postgres anyway, so use the dockerfile associated with Complement instead. -const DockerfilePostgreSQL = `FROM golang:1.22-bookworm as build +const DockerfilePostgreSQL = `FROM golang:1.23-bookworm as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build ARG BINARY @@ -99,7 +99,7 @@ ENV BINARY=dendrite EXPOSE 8008 8448 CMD /build/run_dendrite.sh` -const DockerfileSQLite = `FROM golang:1.22-bookworm as build +const DockerfileSQLite = `FROM golang:1.23-bookworm as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build ARG BINARY diff --git a/go.mod b/go.mod index 044531252..7e06503be 100644 --- a/go.mod +++ b/go.mod @@ -160,6 +160,6 @@ require ( nhooyr.io/websocket v1.8.7 // indirect ) -go 1.22 +go 1.23 toolchain go1.23.2 From 7311d3e1de41b42f5a56791290bf0805ee2570aa Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 8 Jan 2025 12:43:22 +0000 Subject: [PATCH 28/71] more fixes --- syncapi/syncapi_test.go | 48 +++++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 4e1fa7dfb..fbc8ac471 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/element-hq/dendrite/federationapi/statistics" "github.com/element-hq/dendrite/internal/caching" "github.com/element-hq/dendrite/internal/httputil" "github.com/element-hq/dendrite/internal/sqlutil" @@ -27,6 +28,7 @@ import ( "github.com/element-hq/dendrite/syncapi/storage" "github.com/element-hq/dendrite/syncapi/synctypes" + "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/producers" "github.com/element-hq/dendrite/roomserver" "github.com/element-hq/dendrite/roomserver/api" @@ -35,9 +37,14 @@ import ( "github.com/element-hq/dendrite/syncapi/types" "github.com/element-hq/dendrite/test" "github.com/element-hq/dendrite/test/testrig" + usrapi "github.com/element-hq/dendrite/userapi" userapi "github.com/element-hq/dendrite/userapi/api" ) +var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) { + return &statistics.ServerStatistics{}, nil +} + type syncRoomserverAPI struct { rsapi.SyncRoomserverAPI rooms []*test.Room @@ -141,12 +148,15 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} + userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} defer close() jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) msgs := toNATSMsgs(t, cfg, room.Events()...) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) + + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &userVerifier, caching.DisableMetrics) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { @@ -241,12 +251,14 @@ func testSyncEventFormatPowerLevels(t *testing.T, dbType test.DBType) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} + userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} defer close() jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) msgs := toNATSMsgs(t, cfg, room.Events()...) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &userVerifier, caching.DisableMetrics) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { @@ -399,7 +411,9 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { // m.room.history_visibility msgs := toNATSMsgs(t, cfg, room.Events()...) sinceTokens := make([]string, len(msgs)) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) + userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &userVerifier, caching.DisableMetrics) for i, msg := range msgs { testrig.MustPublishMsgs(t, jsctx, msg) time.Sleep(100 * time.Millisecond) @@ -487,7 +501,9 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, caching.DisableMetrics) + userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &userVerifier, caching.DisableMetrics) w := httptest.NewRecorder() routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, @@ -609,7 +625,9 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { // Use the actual internal roomserver API rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, caching.DisableMetrics) + userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &userVerifier, caching.DisableMetrics) for _, tc := range testCases { testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) @@ -879,7 +897,10 @@ func TestGetMembership(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, caching.DisableMetrics) + userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &userVerifier, caching.DisableMetrics) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -946,10 +967,12 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) defer close() natsInstance := jetstream.NATSInstance{} + userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &userVerifier, caching.DisableMetrics) producer := producers.SyncAPIProducer{ TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), @@ -1172,7 +1195,10 @@ func testContext(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, caching.DisableMetrics) + userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, &userVerifier, caching.DisableMetrics) room := test.NewRoom(t, user) @@ -1352,8 +1378,11 @@ func TestRemoveEditedEventFromSearchIndex(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) + userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + room := test.NewRoom(t, user) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &userVerifier, caching.DisableMetrics) if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) @@ -1416,6 +1445,7 @@ func searchRequest(t *testing.T, router *mux.Router, accessToken, searchTerm str assert.NoError(t, err) return body } + func syncUntil(t *testing.T, routers httputil.Routers, accessToken string, skip bool, From 099067646647e3a7a6501e4e45020ba709c11c09 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Thu, 9 Jan 2025 01:16:11 +0000 Subject: [PATCH 29/71] linter fixes --- clientapi/routing/admin.go | 30 +++++++++---------- clientapi/routing/profile.go | 2 +- setup/mscs/msc3861/msc3861_user_verifier.go | 9 +++--- syncapi/syncapi_test.go | 2 +- .../postgres/cross_signing_keys_table.go | 2 +- .../sqlite3/cross_signing_keys_table.go | 2 +- 6 files changed, 23 insertions(+), 24 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 681121546..cc073bdbd 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -539,7 +539,7 @@ func AdminUserDeviceRetrieveCreate( if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - userID, _ := vars["userID"] + userID := vars["userID"] local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) if err != nil { return util.JSONResponse{ @@ -567,7 +567,7 @@ func AdminUserDeviceRetrieveCreate( userDeviceExists := false { var rs api.QueryDevicesResponse - if err := userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil { + if err = userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil { logger.WithError(err).Error("QueryDevices") return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -590,7 +590,7 @@ func AdminUserDeviceRetrieveCreate( if !userDeviceExists { var rs userapi.PerformDeviceCreationResponse - if err := userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ + if err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ Localpart: local, ServerName: domain, DeviceID: &payload.DeviceID, @@ -656,8 +656,8 @@ func AdminUserDeviceDelete( if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - userID, _ := vars["userID"] - deviceID, _ := vars["deviceID"] + userID := vars["userID"] + deviceID := vars["deviceID"] logger := util.GetLogger(req.Context()) // XXX: we probably have to delete session from the sessions dict @@ -718,13 +718,13 @@ func AdminUserDevicesDelete( if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - userID, _ := vars["userID"] + userID := vars["userID"] var payload struct { Devices []string `json:"devices"` } - defer req.Body.Close() + defer req.Body.Close() // nolint: errcheck if err = json.NewDecoder(req.Body).Decode(&payload); err != nil { logger.WithError(err).Error("unable to decode device deletion request") return util.JSONResponse{ @@ -765,7 +765,7 @@ func AdminDeactivateAccount( if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - userID, _ := vars["userID"] + userID := vars["userID"] local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) @@ -836,11 +836,6 @@ func AdminAllowCrossSigningReplacementWithoutUIA( } -type adminExternalID struct { - AuthProvider string `json:"auth_provider"` - ExternalID string `json:"external_id"` -} - type adminCreateOrModifyAccountRequest struct { DisplayName string `json:"displayname"` AvatarURL string `json:"avatar_url"` @@ -848,10 +843,13 @@ type adminCreateOrModifyAccountRequest struct { Medium string `json:"medium"` Address string `json:"address"` } `json:"threepids"` - // TODO: the following fields are not used here, but they are used in Synapse. + // TODO: the following fields are not used by dendrite, but they are used in Synapse. // Password string `json:"password"` // LogoutDevices bool `json:"logout_devices"` - // ExternalIDs []adminExternalID `json:"external_ids"` + // ExternalIDs []struct{ + // AuthProvider string `json:"auth_provider"` + // ExternalID string `json:"external_id"` + // } `json:"external_ids"` // Admin bool `json:"admin"` // Deactivated bool `json:"deactivated"` // Locked bool `json:"locked"` @@ -863,7 +861,7 @@ func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - userID, _ := vars["userID"] + userID := vars["userID"] local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix) if err != nil { return util.JSONResponse{ diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 74bbddbc7..922d5e901 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -204,7 +204,7 @@ func SetDisplayName( // to the provider's pseudo-device and includes only the AccountTypeOIDCService flag. To continue, // we need to replace the admin's device with the user's device var rs userapi.QueryDevicesResponse - err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &rs) + err = userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &rs) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index 59a4c3aa1..88aa11f9e 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -148,6 +148,7 @@ type requester struct { IsGuest bool } +// nolint: gocyclo func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token string) (*requester, error) { var userID *spec.UserID logger := util.GetLogger(ctx) @@ -220,7 +221,7 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st var account *api.Account { var rs api.QueryAccountByLocalpartResponse - err := m.userAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{Localpart: userID.Local(), ServerName: userID.Domain()}, &rs) + err = m.userAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{Localpart: userID.Local(), ServerName: userID.Domain()}, &rs) if err != nil && err != sql.ErrNoRows { logger.WithError(err).Error("QueryAccountByLocalpart") return nil, err @@ -241,7 +242,7 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st } } - if err := m.userAPI.PerformLocalpartExternalUserIDCreation(ctx, &api.PerformLocalpartExternalUserIDCreationRequest{ + if err = m.userAPI.PerformLocalpartExternalUserIDCreation(ctx, &api.PerformLocalpartExternalUserIDCreationRequest{ Localpart: userID.Local(), ExternalID: sub, AuthProvider: externalAuthProvider, @@ -348,7 +349,7 @@ func (m *MSC3861UserVerifier) introspectToken(ctx context.Context, token string) return nil, err } body := resp.Body - defer resp.Body.Close() + defer resp.Body.Close() // nolint: errcheck if c := resp.StatusCode; c < 200 || c >= 300 { return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status, "' response"}, "")) @@ -405,7 +406,7 @@ func fetchOpenIDConfiguration(httpClient *http.Client, authHostURL string) (*Ope if err != nil { return nil, err } - defer resp.Body.Close() + defer resp.Body.Close() // nolint: errcheck if resp.StatusCode != http.StatusOK { return nil, &mscError{Code: codeOpenidConfigEndpointNon2xx, Msg: ".well-known/openid-configuration endpoint returned non-200 response"} } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index fbc8ac471..b06500740 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -626,7 +626,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &userVerifier, caching.DisableMetrics) for _, tc := range testCases { diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index 66ccc9a33..61b2c67d4 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -148,7 +148,7 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( } func (s *crossSigningKeysStatements) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, txn *sql.Tx, userID string, duration time.Duration) (int64, error) { - keyTypeInt, _ := types.KeyTypePurposeToInt[fclient.CrossSigningKeyPurposeMaster] + keyTypeInt := types.KeyTypePurposeToInt[fclient.CrossSigningKeyPurposeMaster] ts := time.Now().Add(duration).UnixMilli() result, err := sqlutil.TxStmt(txn, s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt).ExecContext(ctx, ts, userID, keyTypeInt) if err != nil { diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index c86cf2063..b86bfdf01 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -147,7 +147,7 @@ func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( } func (s *crossSigningKeysStatements) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, txn *sql.Tx, userID string, duration time.Duration) (int64, error) { - keyTypeInt, _ := types.KeyTypePurposeToInt[fclient.CrossSigningKeyPurposeMaster] + keyTypeInt := types.KeyTypePurposeToInt[fclient.CrossSigningKeyPurposeMaster] ts := time.Now().Add(duration).UnixMilli() result, err := sqlutil.TxStmt(txn, s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt).ExecContext(ctx, ts, userID, keyTypeInt) if err != nil { From 1afe2b90eac3aa7c39b750d1a53a602fe99f14f0 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Thu, 9 Jan 2025 01:44:42 +0000 Subject: [PATCH 30/71] fix cross_signing_keys_table --- userapi/storage/postgres/cross_signing_keys_table.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index 61b2c67d4..70592cf63 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -39,8 +39,8 @@ const selectCrossSigningKeysForUserAndKeyTypeSQL = "" + " WHERE user_id = $1 AND key_type = $2" const upsertCrossSigningKeysForUserSQL = "" + - "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + - " VALUES($1, $2, $3)" + + "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data, updatable_without_uia_before_ms)" + + " VALUES($1, $2, $3, $4)" + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" const updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL = "" + From 244021dc055846bdf03a9753e351eed519495d78 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Thu, 9 Jan 2025 02:06:37 +0000 Subject: [PATCH 31/71] deleted test cases TestDevices/sqlite/dupe_token these testcases are irrelevant msc3861 because access tokens are supposed to come from mas and access_token field is not used at all --- userapi/internal/user_api.go | 2 ++ userapi/userapi_test.go | 11 ----------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index b691496ec..2b500c95d 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -306,6 +306,8 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") + // TODO: Since we have deleted access_token's unique constraint from the db, + // we probably should check its uniqueness if msc3861 is disabled dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 6e33ced01..3ff6adfb3 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -445,8 +445,6 @@ func TestAccountData(t *testing.T) { func TestDevices(t *testing.T) { ctx := context.Background() - dupeAccessToken := util.RandomString(8) - displayName := "testing" creationTests := []struct { @@ -468,15 +466,6 @@ func TestDevices(t *testing.T) { name: "explicit local user", inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, }, - { - name: "dupe token - ok", - inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, - }, - { - name: "dupe token - not ok", - inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, - wantErr: true, - }, { name: "test3 second device", // used to test deletion later inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, From 78457f30ed08a32fc1236a077acd916be7f4bdcd Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Thu, 9 Jan 2025 02:29:42 +0000 Subject: [PATCH 32/71] ++ --- syncapi/syncapi_test.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index b06500740..b3dbf86cb 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -148,15 +148,13 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} - userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} defer close() jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) msgs := toNATSMsgs(t, cfg, room.Events()...) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &userVerifier, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, nil, caching.DisableMetrics) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { From 80ee52e09242b488c799c0646c01e4fbd30f5467 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Thu, 9 Jan 2025 03:03:10 +0000 Subject: [PATCH 33/71] fix syncapi tests --- syncapi/syncapi_test.go | 142 ++++++++++++++++++++++++++++++---------- 1 file changed, 108 insertions(+), 34 deletions(-) diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index b3dbf86cb..bd322f664 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "github.com/element-hq/dendrite/federationapi/statistics" "github.com/element-hq/dendrite/internal/caching" "github.com/element-hq/dendrite/internal/httputil" "github.com/element-hq/dendrite/internal/sqlutil" @@ -19,6 +18,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" @@ -28,7 +28,6 @@ import ( "github.com/element-hq/dendrite/syncapi/storage" "github.com/element-hq/dendrite/syncapi/synctypes" - "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/producers" "github.com/element-hq/dendrite/roomserver" "github.com/element-hq/dendrite/roomserver/api" @@ -37,14 +36,9 @@ import ( "github.com/element-hq/dendrite/syncapi/types" "github.com/element-hq/dendrite/test" "github.com/element-hq/dendrite/test/testrig" - usrapi "github.com/element-hq/dendrite/userapi" userapi "github.com/element-hq/dendrite/userapi/api" ) -var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) { - return &statistics.ServerStatistics{}, nil -} - type syncRoomserverAPI struct { rsapi.SyncRoomserverAPI rooms []*test.Room @@ -126,6 +120,20 @@ func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe return nil } +type userVerifier struct { + m map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + } +} + +func (u *userVerifier) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) { + if pair, ok := u.m[req.URL.Query().Get("access_token")]; ok { + return pair.Device, pair.Response + } + return nil, nil +} + func TestSyncAPIAccessTokens(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { testSyncAccessTokens(t, dbType) @@ -153,13 +161,16 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) msgs := toNATSMsgs(t, cfg, room.Events()...) + uv := &userVerifier{} - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, nil, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, uv, caching.DisableMetrics) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { name string req *http.Request + device *userapi.Device + response *util.JSONResponse wantCode int wantJoinedRooms []string }{ @@ -168,6 +179,11 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { req: test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "timeout": "0", })), + device: nil, + response: &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.UnknownToken("Unknown token"), + }, wantCode: 401, }, { @@ -176,6 +192,11 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { "access_token": "foo", "timeout": "0", })), + device: nil, + response: &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.UnknownToken("Unknown token"), + }, wantCode: 401, }, { @@ -184,11 +205,25 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { "access_token": alice.AccessToken, "timeout": "0", })), + device: &alice, + response: nil, wantCode: 200, wantJoinedRooms: []string{room.ID}, }, } + uv.m = make(map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }, len(testCases)) + for _, tc := range testCases { + + uv.m[tc.req.URL.Query().Get("access_token")] = struct { + Device *userapi.Device + Response *util.JSONResponse + }{Device: tc.device, Response: tc.response} + } + syncUntil(t, routers, alice.AccessToken, false, func(syncBody string) bool { // wait for the last sent eventID to come down sync path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) @@ -249,14 +284,20 @@ func testSyncEventFormatPowerLevels(t *testing.T, dbType test.DBType) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} - userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + uv := userVerifier{ + m: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + alice.AccessToken: {Device: &alice, Response: nil}, + }, + } defer close() jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) msgs := toNATSMsgs(t, cfg, room.Events()...) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &userVerifier, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &uv, caching.DisableMetrics) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { @@ -409,9 +450,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { // m.room.history_visibility msgs := toNATSMsgs(t, cfg, room.Events()...) sinceTokens := make([]string, len(msgs)) - userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &userVerifier, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, nil, caching.DisableMetrics) for i, msg := range msgs { testrig.MustPublishMsgs(t, jsctx, msg) time.Sleep(100 * time.Millisecond) @@ -499,9 +538,15 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) - userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &userVerifier, caching.DisableMetrics) + uv := userVerifier{ + m: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + alice.AccessToken: {Device: &alice, Response: nil}, + }, + } + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &uv, caching.DisableMetrics) w := httptest.NewRecorder() routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, @@ -623,9 +668,15 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { // Use the actual internal roomserver API rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &userVerifier, caching.DisableMetrics) + uv := userVerifier{ + m: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + bobDev.AccessToken: {Device: &bobDev, Response: nil}, + }, + } + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &uv, caching.DisableMetrics) for _, tc := range testCases { testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) @@ -894,11 +945,17 @@ func TestGetMembership(t *testing.T) { // Use an actual roomserver for this rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) + uv := userVerifier{ + m: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + aliceDev.AccessToken: {Device: &aliceDev, Response: nil}, + bobDev.AccessToken: {Device: &bobDev, Response: nil}, + }, + } - userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} - - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &userVerifier, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &uv, caching.DisableMetrics) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -965,12 +1022,18 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) defer close() natsInstance := jetstream.NATSInstance{} - userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + uv := userVerifier{ + m: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + alice.AccessToken: {Device: &alice, Response: nil}, + }, + } jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &userVerifier, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &uv, caching.DisableMetrics) producer := producers.SyncAPIProducer{ TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), @@ -1193,10 +1256,16 @@ func testContext(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + uv := userVerifier{ + m: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + alice.AccessToken: {Device: &alice, Response: nil}, + }, + } - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, &userVerifier, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, &uv, caching.DisableMetrics) room := test.NewRoom(t, user) @@ -1375,12 +1444,17 @@ func TestRemoveEditedEventFromSearchIndex(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - - userAPI := usrapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) - userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI} + uv := userVerifier{ + m: map[string]struct { + Device *userapi.Device + Response *util.JSONResponse + }{ + alice.AccessToken: {Device: &alice, Response: nil}, + }, + } room := test.NewRoom(t, user) - AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &userVerifier, caching.DisableMetrics) + AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &uv, caching.DisableMetrics) if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) From 930daa109000f59e427f1f791baf54c284e78332 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Fri, 10 Jan 2025 00:51:29 +0000 Subject: [PATCH 34/71] mas: move org.matrix.cross_signing_reset const from logintypes.go to key_crosssigning.go --- clientapi/auth/authtypes/logintypes.go | 1 - clientapi/routing/key_crosssigning.go | 10 ++++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index c6e67f315..f01e48f80 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -11,5 +11,4 @@ const ( LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" LoginTypeToken = "m.login.token" - LoginTypeCrossSigningReset = "org.matrix.cross_signing_reset" ) diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index d3950e75c..94e660400 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -22,6 +22,8 @@ import ( "github.com/matrix-org/util" ) +const CrossSigningResetStage = "org.matrix.cross_signing_reset" + type crossSigningRequest struct { api.PerformUploadDeviceKeysRequest Auth newPasswordAuth `json:"auth"` @@ -101,7 +103,7 @@ func UploadCrossSigningDeviceKeys( if !masterKeyUpdatableWithoutUIA { url := "" if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" { - url = strings.Join([]string{m.AccountManagementURL, "?action=", authtypes.LoginTypeCrossSigningReset}, "") + url = strings.Join([]string{m.AccountManagementURL, "?action=", CrossSigningResetStage}, "") } else { url = m.Issuer } @@ -111,11 +113,11 @@ func UploadCrossSigningDeviceKeys( "dummy", []authtypes.Flow{ { - Stages: []authtypes.LoginType{authtypes.LoginTypeCrossSigningReset}, + Stages: []authtypes.LoginType{CrossSigningResetStage}, }, }, map[string]interface{}{ - authtypes.LoginTypeCrossSigningReset: map[string]string{ + CrossSigningResetStage: map[string]string{ "url": url, }, }, @@ -128,7 +130,7 @@ func UploadCrossSigningDeviceKeys( } } // XXX: is it necessary? - sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeCrossSigningReset) + sessions.addCompletedSessionStage(sessionID, CrossSigningResetStage) } else { if uploadReq.Auth.Type != authtypes.LoginTypePassword { return util.JSONResponse{ From 0be9b3ca5404de02b6d04663f1cbf7a80ca5ab9e Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Fri, 10 Jan 2025 00:54:35 +0000 Subject: [PATCH 35/71] syncapi_test.go fix --- syncapi/syncapi_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index bd322f664..88db32083 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -673,7 +673,8 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { Device *userapi.Device Response *util.JSONResponse }{ - bobDev.AccessToken: {Device: &bobDev, Response: nil}, + aliceDev.AccessToken: {Device: &aliceDev, Response: nil}, + bobDev.AccessToken: {Device: &bobDev, Response: nil}, }, } AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &uv, caching.DisableMetrics) From 4cde3bafb16e6805ee011e0579df215cd20edd5e Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Fri, 10 Jan 2025 01:23:31 +0000 Subject: [PATCH 36/71] mas: add missing migration for adding x-signing updatable_without_uia_before_ms field --- .../postgres/cross_signing_keys_table.go | 14 ++++++++++- ...signing_updatable_without_uia_before_ms.go | 23 +++++++++++++++++++ .../sqlite3/cross_signing_keys_table.go | 14 ++++++++++- ...signing_updatable_without_uia_before_ms.go | 23 +++++++++++++++++++ 4 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 userapi/storage/postgres/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go create mode 100644 userapi/storage/sqlite3/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index 70592cf63..61b15b884 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -12,6 +12,8 @@ import ( "fmt" "time" + "github.com/element-hq/dendrite/userapi/storage/postgres/deltas" + "github.com/element-hq/dendrite/internal" "github.com/element-hq/dendrite/internal/sqlutil" "github.com/element-hq/dendrite/userapi/storage/tables" @@ -25,7 +27,6 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( user_id TEXT NOT NULL, key_type SMALLINT NOT NULL, key_data TEXT NOT NULL, - updatable_without_uia_before_ms BIGINT DEFAULT NULL, PRIMARY KEY (user_id, key_type) ); ` @@ -64,6 +65,17 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro if err != nil { return nil, err } + m := sqlutil.NewMigrator(db) + m.AddMigrations( + sqlutil.Migration{ + Version: "userapi: add x-signing updatable_without_uia_before_ms", + Up: deltas.UpAddXSigningUpdatableWithoutUIABeforeMs, + }, + ) + err = m.Up(context.Background()) + if err != nil { + return nil, err + } return s, sqlutil.StatementList{ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, diff --git a/userapi/storage/postgres/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go b/userapi/storage/postgres/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go new file mode 100644 index 000000000..11f32aecb --- /dev/null +++ b/userapi/storage/postgres/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go @@ -0,0 +1,23 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpAddXSigningUpdatableWithoutUIABeforeMs(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE keyserver_cross_signing_keys ADD COLUMN IF NOT EXISTS updatable_without_uia_before_ms BIGINT DEFAULT NULL;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddXSigningUpdatableWithoutUIABeforeMs(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE keyserver_cross_signing_keys DROP COLUMN IF EXISTS updatable_without_uia_before_ms;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index b86bfdf01..19f9b6c19 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -12,6 +12,8 @@ import ( "fmt" "time" + "github.com/element-hq/dendrite/userapi/storage/sqlite3/deltas" + "github.com/element-hq/dendrite/internal" "github.com/element-hq/dendrite/internal/sqlutil" "github.com/element-hq/dendrite/userapi/storage/tables" @@ -25,7 +27,6 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( user_id TEXT NOT NULL, key_type INTEGER NOT NULL, key_data TEXT NOT NULL, - updatable_without_uia_before_ms BIGINT DEFAULT NULL, PRIMARY KEY (user_id, key_type) ); ` @@ -63,6 +64,17 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) if err != nil { return nil, err } + m := sqlutil.NewMigrator(db) + m.AddMigrations( + sqlutil.Migration{ + Version: "userapi: add x-signing updatable_without_uia_before_ms", + Up: deltas.UpAddXSigningUpdatableWithoutUIABeforeMs, + }, + ) + err = m.Up(context.Background()) + if err != nil { + return nil, err + } return s, sqlutil.StatementList{ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, diff --git a/userapi/storage/sqlite3/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go b/userapi/storage/sqlite3/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go new file mode 100644 index 000000000..2935509a7 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go @@ -0,0 +1,23 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpAddXSigningUpdatableWithoutUIABeforeMs(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE keyserver_cross_signing_keys ADD COLUMN updatable_without_uia_before_ms BIGINT DEFAULT NULL;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddXSigningUpdatableWithoutUIABeforeMs(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE keyserver_cross_signing_keys DROP COLUMN updatable_without_uia_before_ms;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} From 5ea033d1e46067a3f496694c72cf279a873d9cd4 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Fri, 10 Jan 2025 02:12:30 +0000 Subject: [PATCH 37/71] mas: remove enabled field from msc3861 config + remove some incorrect comments --- dendrite-sample.yaml | 4 +--- setup/config/config_mscs.go | 10 +++------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/dendrite-sample.yaml b/dendrite-sample.yaml index 967d5cfb4..279d3581b 100644 --- a/dendrite-sample.yaml +++ b/dendrite-sample.yaml @@ -285,13 +285,11 @@ media_api: # Configuration for enabling experimental MSCs on this homeserver. mscs: mscs: - # - msc3861 # (Next-gen auth, see https://github.com/matrix-org/matrix-doc/pull/3861. MUST always go first in the list) + # - msc3861 # (Next-gen auth, see https://github.com/matrix-org/matrix-doc/pull/3861) # - msc2836 # (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) # This block has no effect if the feature is not activated in the list above # msc3861: - # enabled: true - # # OIDC issuer advertised by the service. # # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#http # issuer: "https://mas.example.com/" diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index 1523a9ce8..694bb3513 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -6,7 +6,7 @@ type MSCs struct { Matrix *Global `yaml:"-"` // The MSCs to enable. Supported MSCs include: - // 'msc3861': Delegate auth to an OIDC provider. This line MUST always go first if the msc is used https://github.com/matrix-org/matrix-spec-proposals/pull/3861 + // 'msc3861': Delegate auth to an OIDC provider - https://github.com/matrix-org/matrix-spec-proposals/pull/3861 // 'msc2444': Peeking over federation - https://github.com/matrix-org/matrix-doc/pull/2444 // 'msc2753': Peeking via /sync - https://github.com/matrix-org/matrix-doc/pull/2753 // 'msc2836': Threading - https://github.com/matrix-org/matrix-doc/pull/2836 @@ -40,17 +40,16 @@ func (c *MSCs) Verify(configErrs *ConfigErrors) { if c.Matrix.DatabaseOptions.ConnectionString == "" { checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString)) } - if m := c.MSC3861; m != nil { + if m := c.MSC3861; m != nil && c.MSC3861Enabled() { m.Verify(configErrs) } } func (c *MSCs) MSC3861Enabled() bool { - return slices.Contains(c.MSCs, "msc3861") && c.MSC3861 != nil && c.MSC3861.Enabled + return slices.Contains(c.MSCs, "msc3861") && c.MSC3861 != nil } type MSC3861 struct { - Enabled bool `yaml:"enabled"` Issuer string `yaml:"issuer"` ClientID string `yaml:"client_id"` ClientSecret string `yaml:"client_secret"` @@ -59,9 +58,6 @@ type MSC3861 struct { } func (m *MSC3861) Verify(configErrs *ConfigErrors) { - if !m.Enabled { - return - } checkNotEmpty(configErrs, "mscs.msc3861.issuer", string(m.Issuer)) checkNotEmpty(configErrs, "mscs.msc3861.client_id", string(m.ClientID)) checkNotEmpty(configErrs, "mscs.msc3861.client_secret", string(m.ClientSecret)) From 5fd654f8ea49966893f3283fb09b96e96d25e533 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Fri, 10 Jan 2025 23:42:03 +0000 Subject: [PATCH 38/71] Add TestMakeServiceAdminAPI --- internal/httputil/httpapi_test.go | 67 +++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/internal/httputil/httpapi_test.go b/internal/httputil/httpapi_test.go index 23797a5ea..c9dd933cf 100644 --- a/internal/httputil/httpapi_test.go +++ b/internal/httputil/httpapi_test.go @@ -10,6 +10,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/matrix-org/util" ) func TestWrapHandlerInBasicAuth(t *testing.T) { @@ -99,3 +101,68 @@ func TestWrapHandlerInBasicAuth(t *testing.T) { }) } } + +func TestMakeServiceAdminAPI(t *testing.T) { + serviceToken := "valid_secret_token" + type args struct { + f func(*http.Request) util.JSONResponse + serviceToken string + } + + f := func(*http.Request) util.JSONResponse { + return util.JSONResponse{Code: http.StatusOK} + } + + tests := []struct { + name string + args args + want int + reqAuth bool + }{ + { + name: "service token valid", + args: args{ + f: f, + serviceToken: serviceToken, + }, + want: http.StatusOK, + reqAuth: true, + }, + { + name: "service token invalid", + args: args{ + f: f, + serviceToken: "invalid_service_token", + }, + want: http.StatusForbidden, + reqAuth: true, + }, + { + name: "service token is missing", + args: args{ + f: f, + serviceToken: "", + }, + want: http.StatusUnauthorized, + reqAuth: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := MakeServiceAdminAPI("metrics", serviceToken, tt.args.f) + + req := httptest.NewRequest("GET", "http://localhost/admin/v1/username_available", nil) + if tt.reqAuth { + req.Header.Add("Authorization", "Bearer "+tt.args.serviceToken) + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + resp := w.Result() + + if resp.StatusCode != tt.want { + t.Errorf("Expected status code %d, got %d", resp.StatusCode, tt.want) + } + }) + } +} From 5914661ad2b1c422c632bbb3d4d99c1e714d0472 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 15 Jan 2025 01:59:06 +0000 Subject: [PATCH 39/71] mas: add TestVerifyUserFromRequest --- setup/mscs/msc3861/msc3861.go | 1 + setup/mscs/msc3861/msc3861_user_verifier.go | 14 +- .../msc3861/msc3861_user_verifier_test.go | 223 ++++++++++++++++++ 3 files changed, 230 insertions(+), 8 deletions(-) create mode 100644 setup/mscs/msc3861/msc3861_user_verifier_test.go diff --git a/setup/mscs/msc3861/msc3861.go b/setup/mscs/msc3861/msc3861.go index 9b38af31f..b3c458b1b 100644 --- a/setup/mscs/msc3861/msc3861.go +++ b/setup/mscs/msc3861/msc3861.go @@ -8,6 +8,7 @@ func Enable(m *setup.Monolith) error { userVerifier, err := newMSC3861UserVerifier( m.UserAPI, m.Config.Global.ServerName, m.Config.MSCs.MSC3861, !m.Config.ClientAPI.GuestsDisabled, + nil, ) if err != nil { return err diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index 88aa11f9e..fbfdcaa24 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -1,6 +1,7 @@ package msc3861 import ( + "cmp" "context" "database/sql" "encoding/json" @@ -54,8 +55,10 @@ func newMSC3861UserVerifier( serverName spec.ServerName, cfg *config.MSC3861, allowGuest bool, + httpClient *http.Client, ) (*MSC3861UserVerifier, error) { - openIdConfig, err := fetchOpenIDConfiguration(&http.Client{}, cfg.Issuer) + client := cmp.Or(httpClient, http.DefaultClient) + openIdConfig, err := fetchOpenIDConfiguration(client, cfg.Issuer) if err != nil { return nil, err } @@ -65,7 +68,7 @@ func newMSC3861UserVerifier( cfg: cfg, openIdConfig: openIdConfig, allowGuest: allowGuest, - httpClient: http.DefaultClient, + httpClient: client, }, nil } @@ -105,12 +108,7 @@ func (m *MSC3861UserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Dev Code: http.StatusUnauthorized, JSON: spec.UnknownToken(e.Error()), } - case codeAuthError: - return nil, &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.Unknown(e.Error()), - } - case codeMxidError: + case codeAuthError, codeMxidError: return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.Unknown(e.Error()), diff --git a/setup/mscs/msc3861/msc3861_user_verifier_test.go b/setup/mscs/msc3861/msc3861_user_verifier_test.go new file mode 100644 index 000000000..8681decdd --- /dev/null +++ b/setup/mscs/msc3861/msc3861_user_verifier_test.go @@ -0,0 +1,223 @@ +package msc3861 + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "errors" + + "github.com/element-hq/dendrite/federationapi/statistics" + "github.com/element-hq/dendrite/internal/caching" + "github.com/element-hq/dendrite/internal/sqlutil" + "github.com/element-hq/dendrite/roomserver" + "github.com/element-hq/dendrite/setup/config" + "github.com/element-hq/dendrite/setup/jetstream" + "github.com/element-hq/dendrite/test" + "github.com/element-hq/dendrite/test/testrig" + "github.com/element-hq/dendrite/userapi" + uapi "github.com/element-hq/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) { + return &statistics.ServerStatistics{}, nil +} + +type roundTripper struct{} + +func (rt *roundTripper) RoundTrip(request *http.Request) (*http.Response, error) { + var ( + respBody string + statusCode int + ) + + switch request.URL.String() { + case "https://mas.example.com/.well-known/openid-configuration": + respBody = `{"introspection_endpoint": "https://mas.example.com/oauth2/introspect"}` + statusCode = http.StatusOK + case "https://mas.example.com/oauth2/introspect": + _ = request.ParseForm() + + switch request.Form.Get("token") { + case "validTokenUserExistsTokenActive": + statusCode = http.StatusOK + resp := introspectionResponse{ + Active: true, + Scope: "urn:matrix:org.matrix.msc2967.client:device:devAlice urn:matrix:org.matrix.msc2967.client:api:*", + Sub: "111111111111111111", + Username: "1", + } + b, _ := json.Marshal(resp) + respBody = string(b) + case "validTokenUserDoesNotExistTokenActive": + statusCode = http.StatusOK + resp := introspectionResponse{ + Active: true, + Scope: "urn:matrix:org.matrix.msc2967.client:device:devBob urn:matrix:org.matrix.msc2967.client:api:*", + Sub: "222222222222222222", + Username: "2", + } + b, _ := json.Marshal(resp) + respBody = string(b) + case "validTokenUserExistsTokenInactive": + statusCode = http.StatusOK + resp := introspectionResponse{Active: false} + b, _ := json.Marshal(resp) + respBody = string(b) + default: + return nil, errors.New("Request URL not supported by stub") + } + } + + respReader := io.NopCloser(strings.NewReader(respBody)) + resp := http.Response{ + StatusCode: statusCode, + Body: respReader, + ContentLength: int64(len(respBody)), + Header: map[string][]string{"Content-Type": {"application/json"}}, + } + return &resp, nil +} + +func TestVerifyUserFromRequest(t *testing.T) { + httpClient := http.Client{ + Transport: &roundTripper{}, + } + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{ + Issuer: "https://mas.example.com", + } + cfg.ClientAPI.RateLimiting.Enabled = false + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + // Needed for /login + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + userVerifier, err := newMSC3861UserVerifier( + userAPI, + cfg.Global.ServerName, + cfg.MSCs.MSC3861, + false, + &httpClient, + ) + if err != nil { + t.Fatal(err.Error()) + } + u, _ := url.Parse("https://example.com/something") + + aliceUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bobUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + + t.Run("existing user and active token", func(t *testing.T) { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', aliceUser.ID) + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: aliceUser.AccountType, + Localpart: localpart, + ServerName: serverName, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + if !userRes.AccountCreated { + t.Fatalf("account not created") + } + httpReq := http.Request{ + URL: u, + Header: map[string][]string{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer validTokenUserExistsTokenActive"}, + }, + } + device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq) + if jsonResp != nil { + t.Fatalf("JSONResponse is not expected: %+v", jsonResp) + } + deviceRes := uapi.QueryDevicesResponse{} + if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{ + UserID: aliceUser.ID, + }, &deviceRes); err != nil { + t.Errorf("failed to query user devices") + } + if !deviceRes.UserExists { + t.Fatalf("user does not exist") + } + if l := len(deviceRes.Devices); l != 1 { + t.Fatalf("Incorrect number of user devices. Got %d, want 1", l) + } + if device.ID != deviceRes.Devices[0].ID { + t.Fatalf("Device IDs do not match: %s != %s", device.ID, deviceRes.Devices[0].ID) + } + }) + + t.Run("inactive token", func(t *testing.T) { + httpReq := http.Request{ + URL: u, + Header: map[string][]string{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer validTokenUserExistsTokenInactive"}, + }, + } + device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq) + if jsonResp == nil { + t.Fatal("JSONResponse is expected to be nil") + } + if device != nil { + t.Fatalf("Device is not nil: %+v", device) + } + if jsonResp.Code != http.StatusUnauthorized { + t.Fatalf("Incorrect status code: want=401, got=%d", jsonResp.Code) + } + mErr, _ := jsonResp.JSON.(spec.MatrixError) + if mErr.ErrCode != spec.ErrorUnknownToken { + t.Fatalf("Unexpected error code: want=%s, got=%s", spec.ErrorUnknownToken, mErr.ErrCode) + } + }) + + t.Run("non-existing user", func(t *testing.T) { + httpReq := http.Request{ + URL: u, + Header: map[string][]string{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer validTokenUserDoesNotExistTokenActive"}, + }, + } + device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq) + if jsonResp != nil { + t.Fatalf("JSONResponse is not expected: %+v", jsonResp) + } + deviceRes := uapi.QueryDevicesResponse{} + if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{ + UserID: bobUser.ID, + }, &deviceRes); err != nil { + t.Errorf("failed to query user devices") + } + if !deviceRes.UserExists { + t.Fatalf("user does not exist") + } + if l := len(deviceRes.Devices); l != 1 { + t.Fatalf("Incorrect number of user devices. Got %d, want 1", l) + } + if device.ID != deviceRes.Devices[0].ID { + t.Fatalf("Device IDs do not match: %s != %s", device.ID, deviceRes.Devices[0].ID) + } + }) + }) +} From 90e3de322335b7687b97767f1192c28c352bcfc8 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 15 Jan 2025 03:21:12 +0000 Subject: [PATCH 40/71] mas: TestAdminCheckUsernameAvailable --- clientapi/admin_test.go | 107 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 5746aec94..b957215a6 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1491,3 +1491,110 @@ func TestEventReportsGetDelete(t *testing.T) { }) }) } + +func TestAdminCheckUsernameAvailable(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + // Needed for changing the password/login + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: alice.AccountType, + Localpart: alice.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + testCases := []struct { + name string + accessToken string + userID string + wantOK bool + isAvailable bool + }{ + {name: "Missing auth", accessToken: "", wantOK: false, userID: alice.Localpart, isAvailable: false}, + {name: "Alice - user exists", accessToken: adminToken, wantOK: true, userID: alice.Localpart, isAvailable: false}, + {name: "Bob - user does not exist", accessToken: adminToken, wantOK: true, userID: "bob", isAvailable: true}, + } + + for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v1/username_available?username="+tc.userID) + if tc.accessToken != "" { + req.Header.Set("Authorization", "Bearer "+tc.accessToken) + } + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK || !tc.wantOK && rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + if tc.wantOK { + b := make(map[string]bool, 1) + _ = json.NewDecoder(rec.Body).Decode(&b) + available, ok := b["available"] + if !ok { + t.Fatal("'available' not found in body") + } + if available != tc.isAvailable { + t.Fatalf("expected 'available' to be %t, got %t instead", tc.isAvailable, available) + } + } + }) + } + }) +} + +func TestAdminUserDeviceRetrieveCreate(t *testing.T) { + +} + +func TestAdminUserDeviceDelete(t *testing.T) { + +} + +func TestAdminUserDevicesDelete(t *testing.T) { + +} + +func TestAdminDeactivateAccount(t *testing.T) { + +} + +func TestAdminAllowCrossSigningReplacementWithoutUIA(t *testing.T) { + +} + +func TestAdminCreateOrModifyAccount(t *testing.T) { + +} + +func TestAdminRetrieveAccount(t *testing.T) { + +} From 59f73b1ff64cb8247f44d294cd4cd27ff59fda81 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 15 Jan 2025 04:00:18 +0000 Subject: [PATCH 41/71] mas: TestAdminUserDeviceRetrieveCreate --- clientapi/admin_test.go | 107 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index b957215a6..dfc1c33fc 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1514,7 +1514,6 @@ func TestAdminCheckUsernameAvailable(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - // Needed for changing the password/login userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) @@ -1572,7 +1571,113 @@ func TestAdminCheckUsernameAvailable(t *testing.T) { } func TestAdminUserDeviceRetrieveCreate(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice, bob} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID+"/devices") + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + t.Run("Retrieve device", func(t *testing.T) { + var deviceRes uapi.PerformDeviceCreationResponse + if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{ + Localpart: alice.Localpart, + ServerName: cfg.Global.ServerName, + }, &deviceRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID+"/devices") + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + var body struct { + Total int `json:"total"` + Devices []struct { + DeviceID string `json:"device_id"` + } `json:"devices"` + } + _ = json.NewDecoder(rec.Body).Decode(&body) + if body.Total != 1 { + t.Errorf("expected 1 device, got %d", body.Total) + } + if len(body.Devices) != 1 { + t.Errorf("expected 1 device, got %d", len(body.Devices)) + } + }) + + t.Run("Create device", func(t *testing.T) { + reqBody := struct { + DeviceID string `json:"device_id"` + }{DeviceID: "devBob"} + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v2/users/"+bob.ID+"/devices", test.WithJSONBody(t, reqBody)) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusCreated { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusCreated, rec.Code, rec.Body.String()) + } + + var res uapi.QueryDevicesResponse + _ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: bob.ID}, &res) + if len(res.Devices) != 1 { + t.Errorf("expected 1 device, got %d", len(res.Devices)) + } + if res.Devices[0].ID != "devBob" { + t.Errorf("expected device to be devBob, got %s", res.Devices[0].ID) + } + }) + + }) } func TestAdminUserDeviceDelete(t *testing.T) { From f1de5aa838e8d25de843655c12d17adc3cb020b3 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 15 Jan 2025 12:14:14 +0000 Subject: [PATCH 42/71] mas: TestAdminUserDeviceDelete --- clientapi/admin_test.go | 89 ++++++++++++++++++++++++++++++++++++++ clientapi/routing/admin.go | 2 +- 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index dfc1c33fc..88c0b3632 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1681,7 +1681,96 @@ func TestAdminUserDeviceRetrieveCreate(t *testing.T) { } func TestAdminUserDeviceDelete(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + for _, u := range []*test.User{alice} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+alice.ID+"/devices/anything") + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + t.Run("Delete existing device", func(t *testing.T) { + var deviceRes uapi.PerformDeviceCreationResponse + if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{ + Localpart: alice.Localpart, + ServerName: cfg.Global.ServerName, + }, &deviceRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+alice.ID+"/devices/"+deviceRes.Device.ID) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + var rs uapi.QueryDevicesResponse + _ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: alice.ID}, &rs) + if len(rs.Devices) > 0 { + t.Errorf("expected 0 devices, got %d", len(rs.Devices)) + } + }) + + t.Run("Delete non-existing device", func(t *testing.T) { + req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+bob.ID+"/devices/anything") + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + }) } func TestAdminUserDevicesDelete(t *testing.T) { diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index cc073bdbd..3946964d3 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -686,7 +686,7 @@ func AdminUserDeviceDelete( } } - { + if device != nil { // XXX: this response struct can completely removed everywhere as it doesn't // have any functional purpose var res api.PerformDeviceDeletionResponse From 0db7647f46ca41a487e91e382568ffe4340af7c1 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 15 Jan 2025 13:02:28 +0000 Subject: [PATCH 43/71] mas: TestAdminUserDevicesDelete --- clientapi/admin_test.go | 105 ++++++++++++++++++++++++++++++++++++- clientapi/routing/admin.go | 5 +- 2 files changed, 108 insertions(+), 2 deletions(-) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 88c0b3632..ece14a6da 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1759,7 +1759,7 @@ func TestAdminUserDeviceDelete(t *testing.T) { } }) - t.Run("Delete non-existing device", func(t *testing.T) { + t.Run("Delete non-existing user's devices", func(t *testing.T) { req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+bob.ID+"/devices/anything") req.Header.Set("Authorization", "Bearer "+adminToken) @@ -1774,7 +1774,110 @@ func TestAdminUserDeviceDelete(t *testing.T) { } func TestAdminUserDevicesDelete(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + type payload struct { + Devices []string `json:"devices"` + } + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v2/users/"+alice.ID+"/delete_devices") + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + t.Run("Delete existing user's devices", func(t *testing.T) { + var deviceRes uapi.PerformDeviceCreationResponse + if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{ + Localpart: alice.Localpart, + ServerName: cfg.Global.ServerName, + }, &deviceRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + req := test.NewRequest( + t, + http.MethodPost, + "/_synapse/admin/v2/users/"+alice.ID+"/delete_devices", + test.WithJSONBody(t, payload{Devices: []string{deviceRes.Device.ID}}), + ) + req.Header.Set("Authorization", "Bearer "+adminToken) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + var rs uapi.QueryDevicesResponse + _ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: alice.ID}, &rs) + if len(rs.Devices) > 0 { + t.Errorf("expected 0 devices, got %d", len(rs.Devices)) + } + }) + + t.Run("Delete non-existing user's devices", func(t *testing.T) { + req := test.NewRequest( + t, + http.MethodPost, + "/_synapse/admin/v2/users/"+bob.ID+"/delete_devices", + test.WithJSONBody(t, payload{Devices: []string{"anyDevID"}}), + ) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + }) } func TestAdminDeactivateAccount(t *testing.T) { diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 3946964d3..02940128d 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -720,11 +720,13 @@ func AdminUserDevicesDelete( } userID := vars["userID"] + if req.Body == nil { + return util.MessageResponse(http.StatusBadRequest, "body is required") + } var payload struct { Devices []string `json:"devices"` } - defer req.Body.Close() // nolint: errcheck if err = json.NewDecoder(req.Body).Decode(&payload); err != nil { logger.WithError(err).Error("unable to decode device deletion request") return util.JSONResponse{ @@ -732,6 +734,7 @@ func AdminUserDevicesDelete( JSON: spec.InternalServerError{}, } } + defer req.Body.Close() // nolint: errcheck { // XXX: this response struct can completely removed everywhere as it doesn't From 4193b7b197166b50e3719de64630df0642996937 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 15 Jan 2025 13:12:13 +0000 Subject: [PATCH 44/71] mas: TestAdminDeactivateAccount --- clientapi/admin_test.go | 82 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index ece14a6da..de4297239 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1881,7 +1881,89 @@ func TestAdminUserDevicesDelete(t *testing.T) { } func TestAdminDeactivateAccount(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+alice.ID) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + t.Run("Deactivate existing account", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+alice.ID) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + var rs uapi.QueryAccountByLocalpartResponse + _ = userAPI.QueryAccountByLocalpart(ctx, &uapi.QueryAccountByLocalpartRequest{Localpart: alice.Localpart, ServerName: cfg.Global.ServerName}, &rs) + if !rs.Account.Deactivated { + t.Fatalf("expected account is deactivated") + } + }) + + t.Run("Deactivate non-existing account", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+bob.ID) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + }) } func TestAdminAllowCrossSigningReplacementWithoutUIA(t *testing.T) { From e8902dad02696c3d79fffd108496faa8b12d4acf Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 15 Jan 2025 13:30:47 +0000 Subject: [PATCH 45/71] mas: TestAdminRetrieveAccount --- clientapi/admin_test.go | 80 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index de4297239..8cf890e98 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1975,5 +1975,85 @@ func TestAdminCreateOrModifyAccount(t *testing.T) { } func TestAdminRetrieveAccount(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + t.Run("Retrieve existing account", func(t *testing.T) { + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID) + req.Header.Set("Authorization", "Bearer "+adminToken) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + body := `{"display_name":"1","avatar_url":"","deactivated":false}` + if rec.Body.String() != body { + t.Fatalf("expected body %s, got %s", body, rec.Body.String()) + } + }) + + t.Run("Retrieve non-existing account", func(t *testing.T) { + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+bob.ID) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusNotFound { + t.Fatalf("expected http status %d, got %d: %s", http.StatusNotFound, rec.Code, rec.Body.String()) + } + }) + }) } From 5dd8568ecde890740685c1d353afe5bcd60b52f3 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 15 Jan 2025 15:41:11 +0000 Subject: [PATCH 46/71] mas: TestAdminCreateOrModifyAccount --- clientapi/admin_test.go | 172 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 8cf890e98..390e756b6 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1971,7 +1971,179 @@ func TestAdminAllowCrossSigningReplacementWithoutUIA(t *testing.T) { } func TestAdminCreateOrModifyAccount(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + + for _, u := range []*test.User{alice} { + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + type threePID struct { + Medium string `json:"medium"` + Address string `json:"address"` + } + type adminCreateOrModifyAccountRequest struct { + DisplayName string `json:"displayname"` + AvatarURL string `json:"avatar_url"` + ThreePIDs []threePID `json:"threepids"` + } + + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPut, "/_synapse/admin/v2/users/"+alice.ID) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + testCases := []struct { + User *test.User + Payload adminCreateOrModifyAccountRequest + Expected struct { + DisplayName, + AvatarURL string + ThreePIDs []string + } + Code int + NewUser bool + }{ + { + User: alice, + Payload: adminCreateOrModifyAccountRequest{ + DisplayName: "alice", + AvatarURL: "https://alice-avatar.example.com", + ThreePIDs: []threePID{ + { + Medium: "email", + Address: "alice@example.com", + }, + }, + }, + Expected: struct { + DisplayName, AvatarURL string + ThreePIDs []string + }{ + // In order to avoid any confusion and undesired behaviour, we do not change display name and avatar url if account already exists + DisplayName: "1", + AvatarURL: "", + ThreePIDs: []string{"alice@example.com"}, + }, + Code: http.StatusOK, + NewUser: false, + }, + { + User: bob, + Payload: adminCreateOrModifyAccountRequest{ + DisplayName: "bob", + AvatarURL: "https://bob-avatar.example.com", + ThreePIDs: []threePID{ + { + Medium: "email", + Address: "bob@example.com", + }, + }, + }, + Expected: struct { + DisplayName, AvatarURL string + ThreePIDs []string + }{ + DisplayName: "bob", + AvatarURL: "https://bob-avatar.example.com", + ThreePIDs: []string{"bob@example.com"}, + }, + Code: http.StatusCreated, + NewUser: true, + }, + } + + for _, tc := range testCases { + name := "" + if tc.NewUser { + name = fmt.Sprintf("Create user %s", tc.User.ID) + } else { + name = fmt.Sprintf("Modify user %s", tc.User.ID) + } + t.Run(name, func(t *testing.T) { + req := test.NewRequest( + t, + http.MethodPut, + "/_synapse/admin/v2/users/"+tc.User.ID, + test.WithJSONBody(t, tc.Payload), + ) + req.Header.Set("Authorization", "Bearer "+adminToken) + + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != tc.Code { + t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String()) + } + + p, _ := userAPI.QueryProfile(ctx, tc.User.ID) + if p.DisplayName != tc.Expected.DisplayName { + t.Fatalf("expected display name %s, got %s", tc.Expected.DisplayName, p.DisplayName) + } + if p.AvatarURL != tc.Expected.AvatarURL { + t.Fatalf("expected avatar_url %s, got %s", tc.Expected.AvatarURL, p.AvatarURL) + } + var threePidRs uapi.QueryThreePIDsForLocalpartResponse + _ = userAPI.QueryThreePIDsForLocalpart( + ctx, + &uapi.QueryThreePIDsForLocalpartRequest{Localpart: tc.User.Localpart, ServerName: cfg.Global.ServerName}, + &threePidRs, + ) + if len(threePidRs.ThreePIDs) != 1 { + t.Fatalf("expected 1 3pid got %d", len(threePidRs.ThreePIDs)) + } + tp := threePidRs.ThreePIDs[0] + if tp.Medium != "email" { + t.Fatalf("expected 3pid medium email got %s", tp.Medium) + } + if tp.Address != tc.Payload.ThreePIDs[0].Address { + t.Fatalf("expected 3pid address %s got %s", tc.Expected.ThreePIDs[0], tp.Address) + } + }) + + } + }) } func TestAdminRetrieveAccount(t *testing.T) { From 418c584e40785fabc48862e376c5d4ad3799613a Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 15 Jan 2025 18:54:59 +0000 Subject: [PATCH 47/71] mas: TestAdminAllowCrossSigningReplacementWithoutUIA --- clientapi/admin_test.go | 119 ++++++++++++++++++++++++++++++++----- clientapi/routing/admin.go | 7 ++- 2 files changed, 111 insertions(+), 15 deletions(-) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 390e756b6..c2000e42f 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + "github.com/element-hq/dendrite/userapi/types" + "github.com/element-hq/dendrite/federationapi" "github.com/element-hq/dendrite/internal/caching" "github.com/element-hq/dendrite/internal/httputil" @@ -1967,7 +1969,103 @@ func TestAdminDeactivateAccount(t *testing.T) { } func TestAdminAllowCrossSigningReplacementWithoutUIA(t *testing.T) { + alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + adminToken := "superSecretAdminToken" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + // add a vhost + cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}, + }) + // There's no need to add a full config for msc3861 as we need only an admin token + cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"} + cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics) + t.Run("Missing auth token", func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/users/"+alice.ID+"/_allow_cross_signing_replacement_without_uia") + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String()) + } + var b spec.MatrixError + _ = json.NewDecoder(rec.Body).Decode(&b) + if b.ErrCode != spec.ErrorMissingToken { + t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode) + } + }) + + for _, u := range []*test.User{alice} { + var userRes uapi.PerformAccountCreationResponse + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: u.Localpart, + ServerName: cfg.Global.ServerName, + Password: "", + }, &userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + _ = userAPI.KeyDatabase.StoreCrossSigningKeysForUser(ctx, alice.ID, types.CrossSigningKeyMap{ + fclient.CrossSigningKeyPurposeMaster: types.CrossSigningKey{ + KeyData: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), + }, + fclient.CrossSigningKeyPurposeSelfSigning: types.CrossSigningKey{ + KeyData: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), + }, + fclient.CrossSigningKeyPurposeUserSigning: types.CrossSigningKey{ + KeyData: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), + }, + }, nil) + + } + + testCases := []struct { + Name string + User *test.User + Code int + }{ + {Name: "existing user", User: alice, Code: 200}, + {Name: "non-existing user", User: bob, Code: 404}, + } + + now := time.Now() + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/users/"+tc.User.ID+"/_allow_cross_signing_replacement_without_uia") + req.Header.Set("Authorization", "Bearer "+adminToken) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + + if rec.Code != tc.Code { + t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String()) + } + + if rec.Code == 200 { + buf := make(map[string]int64, 1) + _ = json.NewDecoder(rec.Body).Decode(&buf) + if ts := buf["updatable_without_uia_before_ms"]; ts <= now.UnixMilli() { + t.Fatalf("expected updatable_without_uia_before_ms is in future, got %d", ts) + } + } + }) + } + }) } func TestAdminCreateOrModifyAccount(t *testing.T) { @@ -2035,6 +2133,7 @@ func TestAdminCreateOrModifyAccount(t *testing.T) { }) testCases := []struct { + Name string User *test.User Payload adminCreateOrModifyAccountRequest Expected struct { @@ -2042,10 +2141,10 @@ func TestAdminCreateOrModifyAccount(t *testing.T) { AvatarURL string ThreePIDs []string } - Code int - NewUser bool + Code int }{ { + Name: fmt.Sprintf("Modify user %s", alice.ID), User: alice, Payload: adminCreateOrModifyAccountRequest{ DisplayName: "alice", @@ -2066,10 +2165,10 @@ func TestAdminCreateOrModifyAccount(t *testing.T) { AvatarURL: "", ThreePIDs: []string{"alice@example.com"}, }, - Code: http.StatusOK, - NewUser: false, + Code: http.StatusOK, }, { + Name: fmt.Sprintf("Create user %s", bob.ID), User: bob, Payload: adminCreateOrModifyAccountRequest{ DisplayName: "bob", @@ -2089,19 +2188,12 @@ func TestAdminCreateOrModifyAccount(t *testing.T) { AvatarURL: "https://bob-avatar.example.com", ThreePIDs: []string{"bob@example.com"}, }, - Code: http.StatusCreated, - NewUser: true, + Code: http.StatusCreated, }, } for _, tc := range testCases { - name := "" - if tc.NewUser { - name = fmt.Sprintf("Create user %s", tc.User.ID) - } else { - name = fmt.Sprintf("Modify user %s", tc.User.ID) - } - t.Run(name, func(t *testing.T) { + t.Run(tc.Name, func(t *testing.T) { req := test.NewRequest( t, http.MethodPut, @@ -2141,7 +2233,6 @@ func TestAdminCreateOrModifyAccount(t *testing.T) { t.Fatalf("expected 3pid address %s got %s", tc.Expected.ThreePIDs[0], tp.Address) } }) - } }) } diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 02940128d..a7bd4886b 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -819,12 +819,17 @@ func AdminAllowCrossSigningReplacementWithoutUIA( } var rs userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse err = userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA(req.Context(), &rq, &rs) - if err != nil && err != sql.ErrNoRows { + if err != nil && !errors.Is(err, sql.ErrNoRows) { util.GetLogger(req.Context()).WithError(err).Error("userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.Unknown(err.Error()), } + } else if errors.Is(err, sql.ErrNoRows) { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("User not found."), + } } return util.JSONResponse{ Code: http.StatusOK, From 3619a6de8dd0d2d6ec0331da879f9f9cf272d4dc Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 15 Jan 2025 19:18:29 +0000 Subject: [PATCH 48/71] mas: refactoring --- clientapi/admin_test.go | 60 +++++---- .../msc3861/msc3861_user_verifier_test.go | 114 +++++++++--------- 2 files changed, 95 insertions(+), 79 deletions(-) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index c2000e42f..2d6cc1086 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -2161,7 +2161,7 @@ func TestAdminCreateOrModifyAccount(t *testing.T) { ThreePIDs []string }{ // In order to avoid any confusion and undesired behaviour, we do not change display name and avatar url if account already exists - DisplayName: "1", + DisplayName: alice.Localpart, AvatarURL: "", ThreePIDs: []string{"alice@example.com"}, }, @@ -2291,32 +2291,42 @@ func TestAdminRetrieveAccount(t *testing.T) { } }) - t.Run("Retrieve existing account", func(t *testing.T) { - req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID) - req.Header.Set("Authorization", "Bearer "+adminToken) + testCase := []struct { + Name string + User *test.User + Code int + Body string + }{ + { + Name: "Retrieve existing account", + User: alice, + Code: http.StatusOK, + Body: fmt.Sprintf(`{"display_name":"%s","avatar_url":"","deactivated":false}`, alice.Localpart), + }, + { + Name: "Retrieve non-existing account", + User: bob, + Code: http.StatusNotFound, + Body: "", + }, + } - rec := httptest.NewRecorder() - routers.SynapseAdmin.ServeHTTP(rec, req) - t.Logf("%s", rec.Body.String()) - if rec.Code != http.StatusOK { - t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) - } - body := `{"display_name":"1","avatar_url":"","deactivated":false}` - if rec.Body.String() != body { - t.Fatalf("expected body %s, got %s", body, rec.Body.String()) - } - }) + for _, tc := range testCase { + t.Run("Retrieve existing account", func(t *testing.T) { + req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+tc.User.ID) + req.Header.Set("Authorization", "Bearer "+adminToken) - t.Run("Retrieve non-existing account", func(t *testing.T) { - req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+bob.ID) - req.Header.Set("Authorization", "Bearer "+adminToken) + rec := httptest.NewRecorder() + routers.SynapseAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if rec.Code != tc.Code { + t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String()) + } - rec := httptest.NewRecorder() - routers.SynapseAdmin.ServeHTTP(rec, req) - t.Logf("%s", rec.Body.String()) - if rec.Code != http.StatusNotFound { - t.Fatalf("expected http status %d, got %d: %s", http.StatusNotFound, rec.Code, rec.Body.String()) - } - }) + if tc.Body != "" && tc.Body != rec.Body.String() { + t.Fatalf("expected body %s, got %s", tc.Body, rec.Body.String()) + } + }) + } }) } diff --git a/setup/mscs/msc3861/msc3861_user_verifier_test.go b/setup/mscs/msc3861/msc3861_user_verifier_test.go index 8681decdd..0bfe68e7f 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier_test.go +++ b/setup/mscs/msc3861/msc3861_user_verifier_test.go @@ -31,65 +31,74 @@ var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerS return &statistics.ServerStatistics{}, nil } -type roundTripper struct{} +type roundTripper struct { + roundTrip func(request *http.Request) (*http.Response, error) +} func (rt *roundTripper) RoundTrip(request *http.Request) (*http.Response, error) { - var ( - respBody string - statusCode int - ) - - switch request.URL.String() { - case "https://mas.example.com/.well-known/openid-configuration": - respBody = `{"introspection_endpoint": "https://mas.example.com/oauth2/introspect"}` - statusCode = http.StatusOK - case "https://mas.example.com/oauth2/introspect": - _ = request.ParseForm() - - switch request.Form.Get("token") { - case "validTokenUserExistsTokenActive": - statusCode = http.StatusOK - resp := introspectionResponse{ - Active: true, - Scope: "urn:matrix:org.matrix.msc2967.client:device:devAlice urn:matrix:org.matrix.msc2967.client:api:*", - Sub: "111111111111111111", - Username: "1", - } - b, _ := json.Marshal(resp) - respBody = string(b) - case "validTokenUserDoesNotExistTokenActive": - statusCode = http.StatusOK - resp := introspectionResponse{ - Active: true, - Scope: "urn:matrix:org.matrix.msc2967.client:device:devBob urn:matrix:org.matrix.msc2967.client:api:*", - Sub: "222222222222222222", - Username: "2", - } - b, _ := json.Marshal(resp) - respBody = string(b) - case "validTokenUserExistsTokenInactive": + return rt.roundTrip(request) +} + +func TestVerifyUserFromRequest(t *testing.T) { + aliceUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + bobUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + + roundTrip := func(request *http.Request) (*http.Response, error) { + var ( + respBody string + statusCode int + ) + + switch request.URL.String() { + case "https://mas.example.com/.well-known/openid-configuration": + respBody = `{"introspection_endpoint": "https://mas.example.com/oauth2/introspect"}` statusCode = http.StatusOK - resp := introspectionResponse{Active: false} - b, _ := json.Marshal(resp) - respBody = string(b) - default: - return nil, errors.New("Request URL not supported by stub") + case "https://mas.example.com/oauth2/introspect": + _ = request.ParseForm() + + switch request.Form.Get("token") { + case "validTokenUserExistsTokenActive": + statusCode = http.StatusOK + resp := introspectionResponse{ + Active: true, + Scope: "urn:matrix:org.matrix.msc2967.client:device:devAlice urn:matrix:org.matrix.msc2967.client:api:*", + Sub: "111111111111111111", + Username: aliceUser.Localpart, + } + b, _ := json.Marshal(resp) + respBody = string(b) + case "validTokenUserDoesNotExistTokenActive": + statusCode = http.StatusOK + resp := introspectionResponse{ + Active: true, + Scope: "urn:matrix:org.matrix.msc2967.client:device:devBob urn:matrix:org.matrix.msc2967.client:api:*", + Sub: "222222222222222222", + Username: bobUser.Localpart, + } + b, _ := json.Marshal(resp) + respBody = string(b) + case "validTokenUserExistsTokenInactive": + statusCode = http.StatusOK + resp := introspectionResponse{Active: false} + b, _ := json.Marshal(resp) + respBody = string(b) + default: + return nil, errors.New("Request URL not supported by stub") + } } - } - respReader := io.NopCloser(strings.NewReader(respBody)) - resp := http.Response{ - StatusCode: statusCode, - Body: respReader, - ContentLength: int64(len(respBody)), - Header: map[string][]string{"Content-Type": {"application/json"}}, + respReader := io.NopCloser(strings.NewReader(respBody)) + resp := http.Response{ + StatusCode: statusCode, + Body: respReader, + ContentLength: int64(len(respBody)), + Header: map[string][]string{"Content-Type": {"application/json"}}, + } + return &resp, nil } - return &resp, nil -} -func TestVerifyUserFromRequest(t *testing.T) { httpClient := http.Client{ - Transport: &roundTripper{}, + Transport: &roundTripper{roundTrip: roundTrip}, } ctx := context.Background() @@ -123,9 +132,6 @@ func TestVerifyUserFromRequest(t *testing.T) { } u, _ := url.Parse("https://example.com/something") - aliceUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) - bobUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) - t.Run("existing user and active token", func(t *testing.T) { localpart, serverName, _ := gomatrixserverlib.SplitID('@', aliceUser.ID) userRes := &uapi.PerformAccountCreationResponse{} From 64f308b55cf708c3e28ffe6d69b3c84cc4180df0 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Fri, 17 Jan 2025 00:09:42 +0000 Subject: [PATCH 49/71] mas: add missing server_name field to sqlite migration --- .../deltas/2024123101150000_drop_primary_key_constraint.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go b/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go index def7a75e2..3758819bd 100644 --- a/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go +++ b/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go @@ -24,9 +24,9 @@ func UpDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error { ); INSERT INTO userapi_devices ( - access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent + access_token, session_id, device_id, localpart, server_name, created_ts, display_name, last_seen_ts, ip, user_agent ) SELECT - access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', '' + access_token, session_id, device_id, localpart, server_name, created_ts, display_name, created_ts, '', '' FROM userapi_devices_tmp; DROP TABLE userapi_devices_tmp;`) if err != nil { From b44f899637b7255623f9895d7ef2f8e5b3d871f4 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Fri, 17 Jan 2025 02:56:17 +0000 Subject: [PATCH 50/71] mas: cross signing fixes after merge --- clientapi/routing/key_crosssigning.go | 145 +++++++++++++++----------- 1 file changed, 87 insertions(+), 58 deletions(-) diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 2a1321f68..f287e31f3 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -9,10 +9,9 @@ package routing import ( "context" "net/http" + "strings" "time" - "github.com/sirupsen/logrus" - "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/clientapi/auth/authtypes" "github.com/element-hq/dendrite/clientapi/httputil" @@ -32,6 +31,7 @@ type crossSigningRequest struct { type UploadKeysAPI interface { QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) + QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) api.UploadDeviceKeysAPI } @@ -40,6 +40,7 @@ func UploadCrossSigningDeviceKeys( keyserverAPI UploadKeysAPI, device *api.Device, accountAPI auth.GetAccountByPassword, cfg *config.ClientAPI, ) util.JSONResponse { + logger := util.GetLogger(req.Context()) uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} @@ -48,6 +49,11 @@ func UploadCrossSigningDeviceKeys( return *resErr } + sessionID := uploadReq.Auth.Session + if sessionID == "" { + sessionID = util.RandomString(sessionIDLength) + } + // Query existing keys to determine if UIA is required keyResp := api.QueryKeysResponse{} keyserverAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ @@ -57,78 +63,101 @@ func UploadCrossSigningDeviceKeys( }, &keyResp) if keyResp.Error != nil { - logrus.WithError(keyResp.Error).Error("Failed to query keys") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown(keyResp.Error.Error()), - } + logger.WithError(keyResp.Error).Error("Failed to query keys") + return convertKeyError(keyResp.Error) } existingMasterKey, hasMasterKey := keyResp.MasterKeys[device.UserID] - requireUIA := false - if hasMasterKey { - // If we have a master key, check if any of the existing keys differ. If they do, - // we need to re-authenticate the user. - requireUIA = keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) - } + requireUIA := true - if requireUIA { - sessionID := uploadReq.Auth.Session - if sessionID == "" { - sessionID = util.RandomString(sessionIDLength) - } - if uploadReq.Auth.Type != authtypes.LoginTypePassword { + if hasMasterKey { + if !keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) { + // If we have a master key, check if any of the existing keys differ. If they don't + // we return 200 as keys are still valid and there's nothing to do. return util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: newUserInteractiveResponse( - sessionID, - []authtypes.Flow{ - { - Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, - }, - }, - nil, - ), + Code: http.StatusOK, + JSON: struct{}{}, } } - typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountAPI, - Config: cfg, - } - if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { - return *authErr - } - sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) - } - uploadReq.UserID = device.UserID - keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) + // With MSC3861, UIA is not possible. Instead, the auth service has to explicitly mark the master key as replaceable. + if cfg.MSCs.MSC3861Enabled() { + masterKeyResp := api.QueryMasterKeysResponse{} + keyserverAPI.QueryMasterKeys(req.Context(), &api.QueryMasterKeysRequest{UserID: device.UserID}, &masterKeyResp) - if err := uploadRes.Error; err != nil { - switch { - case err.IsInvalidSignature: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.InvalidSignature(err.Error()), + if masterKeyResp.Error != nil { + logger.WithError(masterKeyResp.Error).Error("Failed to query master key") + return convertKeyError(masterKeyResp.Error) } - case err.IsMissingParam: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.MissingParam(err.Error()), + if k := masterKeyResp.Key; k != nil && k.UpdatableWithoutUIABeforeMs != nil { + requireUIA = !(time.Now().UnixMilli() < *k.UpdatableWithoutUIABeforeMs) } - case err.IsInvalidParam: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.InvalidParam(err.Error()), + + if requireUIA { + url := "" + if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" { + url = strings.Join([]string{m.AccountManagementURL, "?action=", CrossSigningResetStage}, "") + } else { + url = m.Issuer + } + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + "dummy", + []authtypes.Flow{ + { + Stages: []authtypes.LoginType{CrossSigningResetStage}, + }, + }, + map[string]interface{}{ + CrossSigningResetStage: map[string]string{ + "url": url, + }, + }, + strings.Join([]string{ + "To reset your end-to-end encryption cross-signing identity, you first need to approve it at", + url, + "and then try again.", + }, " "), + ), + } } - default: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown(err.Error()), + // XXX: is it necessary? + sessions.addCompletedSessionStage(sessionID, CrossSigningResetStage) + } else { + if uploadReq.Auth.Type != authtypes.LoginTypePassword { + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + sessionID, + []authtypes.Flow{ + { + Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, + }, + }, + nil, + "", + ), + } + } + typePassword := auth.LoginTypePassword{ + GetAccountByPassword: accountAPI, + Config: cfg, } + if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { + return *authErr + } + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) } } + uploadReq.UserID = device.UserID + keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) + + if err := uploadRes.Error; err != nil { + return convertKeyError(err) + } + return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, From 021431c710f980728f1cf8cb8d083534ccd5d59f Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Fri, 17 Jan 2025 03:24:07 +0000 Subject: [PATCH 51/71] mas: fix key_crosssigning_test.go --- clientapi/routing/key_crosssigning_test.go | 75 +++++++++++++++++----- setup/config/config_mscs.go | 4 +- 2 files changed, 61 insertions(+), 18 deletions(-) diff --git a/clientapi/routing/key_crosssigning_test.go b/clientapi/routing/key_crosssigning_test.go index 0ebb91e07..1339b45ce 100644 --- a/clientapi/routing/key_crosssigning_test.go +++ b/clientapi/routing/key_crosssigning_test.go @@ -10,6 +10,8 @@ import ( "strings" "testing" + "github.com/element-hq/dendrite/userapi/types" + "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/test" "github.com/element-hq/dendrite/test/testrig" @@ -20,19 +22,28 @@ import ( ) type mockKeyAPI struct { - t *testing.T - userResponses map[string]api.QueryKeysResponse + t *testing.T + queryKeysData map[string]api.QueryKeysResponse + queryMasterKeysData map[string]api.QueryMasterKeysResponse } func (m mockKeyAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { - res.MasterKeys = m.userResponses[req.UserID].MasterKeys - res.SelfSigningKeys = m.userResponses[req.UserID].SelfSigningKeys - res.UserSigningKeys = m.userResponses[req.UserID].UserSigningKeys + res.MasterKeys = m.queryKeysData[req.UserID].MasterKeys + res.SelfSigningKeys = m.queryKeysData[req.UserID].SelfSigningKeys + res.UserSigningKeys = m.queryKeysData[req.UserID].UserSigningKeys if m.t != nil { m.t.Logf("QueryKeys: %+v => %+v", req, res) } } +func (m mockKeyAPI) QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) { + res.Key = m.queryMasterKeysData[req.UserID].Key + res.Error = m.queryMasterKeysData[req.UserID].Error + if m.t != nil { + m.t.Logf("QueryMasterKeys: %+v => %+v", req, res) + } +} + func (m mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { // Just a dummy upload which always succeeds } @@ -53,13 +64,19 @@ func Test_UploadCrossSigningDeviceKeys_ValidRequest(t *testing.T) { req.Header.Set("Content-Type", "application/json") keyserverAPI := &mockKeyAPI{ - userResponses: map[string]api.QueryKeysResponse{ + queryKeysData: map[string]api.QueryKeysResponse{ + "@user:example.com": {}, + }, + queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ "@user:example.com": {}, }, } device := &api.Device{UserID: "@user:example.com", ID: "device"} - cfg := &config.ClientAPI{} - + cfg := &config.ClientAPI{ + MSCs: &config.MSCs{ + MSCs: []string{}, + }, + } res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg) if res.Code != http.StatusOK { t.Fatalf("expected status %d, got %d", http.StatusOK, res.Code) @@ -101,18 +118,32 @@ func Test_UploadCrossSigningDeviceKeys_Unauthorised(t *testing.T) { keyserverAPI := &mockKeyAPI{ t: t, - userResponses: map[string]api.QueryKeysResponse{ + queryKeysData: map[string]api.QueryKeysResponse{ "@user:example.com": { MasterKeys: map[string]fclient.CrossSigningKey{ - "@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}}, + "@user:example.com": { + UserID: "@user:example.com", + Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster}, + Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}}, }, SelfSigningKeys: nil, UserSigningKeys: nil, }, }, + queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ + "@user:example.com": { + Key: &types.CrossSigningKey{ + KeyData: spec.Base64Bytes("key1"), + }, + }, + }, } device := &api.Device{UserID: "@user:example.com", ID: "device"} - cfg := &config.ClientAPI{} + cfg := &config.ClientAPI{ + MSCs: &config.MSCs{ + MSCs: []string{}, + }, + } res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg) if res.Code != http.StatusUnauthorized { @@ -132,8 +163,11 @@ func Test_UploadCrossSigningDeviceKeys_InvalidJSON(t *testing.T) { keyserverAPI := &mockKeyAPI{} device := &api.Device{UserID: "@user:example.com", ID: "device"} - cfg := &config.ClientAPI{} - + cfg := &config.ClientAPI{ + MSCs: &config.MSCs{ + MSCs: []string{}, + }, + } res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg) if res.Code != http.StatusBadRequest { t.Fatalf("expected status %d, got %d", http.StatusBadRequest, res.Code) @@ -151,10 +185,21 @@ func Test_UploadCrossSigningDeviceKeys_ExistingKeysMismatch(t *testing.T) { req.Header.Set("Content-Type", "application/json") keyserverAPI := &mockKeyAPI{ - userResponses: map[string]api.QueryKeysResponse{ + queryKeysData: map[string]api.QueryKeysResponse{ "@user:example.com": { MasterKeys: map[string]fclient.CrossSigningKey{ - "@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")}}, + "@user:example.com": { + UserID: "@user:example.com", + Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster}, + Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")}, + }, + }, + }, + }, + queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ + "@user:example.com": { + Key: &types.CrossSigningKey{ + KeyData: spec.Base64Bytes("different_key"), }, }, }, diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index 694bb3513..c8f2249a1 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -1,7 +1,5 @@ package config -import "slices" - type MSCs struct { Matrix *Global `yaml:"-"` @@ -46,7 +44,7 @@ func (c *MSCs) Verify(configErrs *ConfigErrors) { } func (c *MSCs) MSC3861Enabled() bool { - return slices.Contains(c.MSCs, "msc3861") && c.MSC3861 != nil + return c.Enabled("msc3861") && c.MSC3861 != nil } type MSC3861 struct { From 641f5b54f8e1ebc7ff3d08bf45bc9b42ed5454d1 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Fri, 17 Jan 2025 03:45:25 +0000 Subject: [PATCH 52/71] mas: todo comment --- clientapi/routing/key_crosssigning_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/clientapi/routing/key_crosssigning_test.go b/clientapi/routing/key_crosssigning_test.go index 1339b45ce..2d0893626 100644 --- a/clientapi/routing/key_crosssigning_test.go +++ b/clientapi/routing/key_crosssigning_test.go @@ -21,6 +21,8 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) +// TODO: add more tests to cover cases related to MSC3861 + type mockKeyAPI struct { t *testing.T queryKeysData map[string]api.QueryKeysResponse From 17b7677071a608abb7afa4a9c4fbac9bb3affa56 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 22 Jan 2025 22:50:37 +0000 Subject: [PATCH 53/71] fix typo in api.QueryAccessTokenAPI --- clientapi/auth/default_user_verifier.go | 2 +- userapi/api/api.go | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/clientapi/auth/default_user_verifier.go b/clientapi/auth/default_user_verifier.go index f0a48f518..6e6746d11 100644 --- a/clientapi/auth/default_user_verifier.go +++ b/clientapi/auth/default_user_verifier.go @@ -11,7 +11,7 @@ import ( // DefaultUserVerifier implements UserVerifier interface type DefaultUserVerifier struct { - UserAPI api.QueryAcccessTokenAPI + UserAPI api.QueryAccessTokenAPI } // VerifyUserFromRequest authenticates the HTTP request, diff --git a/userapi/api/api.go b/userapi/api/api.go index 08308bc3b..8a60e8f06 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -48,7 +48,7 @@ type RoomserverUserAPI interface { // api functions required by the media api type MediaUserAPI interface { - QueryAcccessTokenAPI + QueryAccessTokenAPI } // api functions required by the federation api @@ -65,7 +65,7 @@ type FederationUserAPI interface { // api functions required by the sync api type SyncUserAPI interface { - QueryAcccessTokenAPI + QueryAccessTokenAPI SyncKeyAPI QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error @@ -76,7 +76,7 @@ type SyncUserAPI interface { // api functions required by the client api type ClientUserAPI interface { - QueryAcccessTokenAPI + QueryAccessTokenAPI LoginTokenInternalAPI UserLoginAPI ClientKeyAPI @@ -132,9 +132,8 @@ type QuerySearchProfilesAPI interface { QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error } -// FIXME: typo in Acccess // common function for creating authenticated endpoints (used in client/media/sync api) -type QueryAcccessTokenAPI interface { +type QueryAccessTokenAPI interface { QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error } From 8a05a66cd74a7437db5a9a2c18c07f242478ffbe Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Thu, 23 Jan 2025 01:25:04 +0000 Subject: [PATCH 54/71] code review fixes --- clientapi/routing/routing.go | 2 +- setup/config/config_clientapi.go | 6 ++--- setup/config/config_mscs.go | 3 ++- syncapi/syncapi_test.go | 22 +++++++++---------- userapi/api/api.go | 2 +- userapi/storage/postgres/accounts_table.go | 3 +-- .../postgres/localpart_external_ids_table.go | 4 ++-- userapi/storage/sqlite3/accounts_table.go | 2 +- .../sqlite3/localpart_external_ids_table.go | 4 ++-- 9 files changed, 24 insertions(+), 24 deletions(-) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 2a8e8427b..15a5addfb 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -376,7 +376,7 @@ func Setup( })).Methods(http.MethodPost) } else { // If msc3861 is enabled, these endpoints are either redundant or replaced by Matrix Auth Service (MAS) - // Once we migrate to MAS completely, these edndpoints should be removed + // Once we migrate to MAS completely, these endpoints should be removed v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req, nil); r != nil { diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 4f6a7ac20..3e683059c 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -76,10 +76,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors) { c.RateLimiting.Verify(configErrs) if c.MSCs.MSC3861Enabled() { - if c.RecaptchaEnabled || !c.RegistrationDisabled { + if !c.RegistrationDisabled || c.RecaptchaEnabled { configErrs.Add( - "You have enabled the experimental feature MSC3861 which implements the delegated authentication via OIDC." + - "As a result, the feature conflicts with the standard Dendrite's registration and login flows and cannot be used if any of those is enabled." + + "You have enabled the experimental feature MSC3861 which implements the delegated authentication via OIDC. " + + "As a result, the feature conflicts with the standard Dendrite's registration and login flows and cannot be used if any of those is enabled. " + "You need to disable registration (client_api.registration_disabled) and recapthca (client_api.enable_registration_captcha) options to proceed.", ) } diff --git a/setup/config/config_mscs.go b/setup/config/config_mscs.go index c8f2249a1..fb0c547fe 100644 --- a/setup/config/config_mscs.go +++ b/setup/config/config_mscs.go @@ -10,7 +10,8 @@ type MSCs struct { // 'msc2836': Threading - https://github.com/matrix-org/matrix-doc/pull/2836 MSCs []string `yaml:"mscs"` - // MSC3861 contains config related to the experimental feature MSC3861. It takes effect only if 'msc3861' is included in 'MSCs' array + // MSC3861 contains config related to the experimental feature MSC3861. + // It takes effect only if 'msc3861' is included in 'MSCs' array. MSC3861 *MSC3861 `yaml:"msc3861,omitempty"` Database DatabaseOptions `yaml:"database,omitempty"` diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 88db32083..efd283826 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -121,14 +121,14 @@ func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe } type userVerifier struct { - m map[string]struct { + accessTokenToDeviceAndResponse map[string]struct { Device *userapi.Device Response *util.JSONResponse } } func (u *userVerifier) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) { - if pair, ok := u.m[req.URL.Query().Get("access_token")]; ok { + if pair, ok := u.accessTokenToDeviceAndResponse[req.URL.Query().Get("access_token")]; ok { return pair.Device, pair.Response } return nil, nil @@ -212,13 +212,13 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { }, } - uv.m = make(map[string]struct { + uv.accessTokenToDeviceAndResponse = make(map[string]struct { Device *userapi.Device Response *util.JSONResponse }, len(testCases)) for _, tc := range testCases { - uv.m[tc.req.URL.Query().Get("access_token")] = struct { + uv.accessTokenToDeviceAndResponse[tc.req.URL.Query().Get("access_token")] = struct { Device *userapi.Device Response *util.JSONResponse }{Device: tc.device, Response: tc.response} @@ -285,7 +285,7 @@ func testSyncEventFormatPowerLevels(t *testing.T, dbType test.DBType) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} uv := userVerifier{ - m: map[string]struct { + accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse }{ @@ -539,7 +539,7 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) uv := userVerifier{ - m: map[string]struct { + accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse }{ @@ -669,7 +669,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) uv := userVerifier{ - m: map[string]struct { + accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse }{ @@ -947,7 +947,7 @@ func TestGetMembership(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) uv := userVerifier{ - m: map[string]struct { + accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse }{ @@ -1024,7 +1024,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { defer close() natsInstance := jetstream.NATSInstance{} uv := userVerifier{ - m: map[string]struct { + accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse }{ @@ -1258,7 +1258,7 @@ func testContext(t *testing.T, dbType test.DBType) { rsAPI.SetFederationAPI(nil, nil) uv := userVerifier{ - m: map[string]struct { + accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse }{ @@ -1446,7 +1446,7 @@ func TestRemoveEditedEventFromSearchIndex(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) uv := userVerifier{ - m: map[string]struct { + accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse }{ diff --git a/userapi/api/api.go b/userapi/api/api.go index 8a60e8f06..334b6a1ed 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -483,7 +483,7 @@ type LocalpartExternalID struct { Localpart string ExternalID string AuthProvider string - CreatedTs int64 + CreatedTS int64 } // UserInfo is for returning information about the user an OpenID token was issued for diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 5c0519962..19c16230b 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -116,8 +116,7 @@ func (s *accountsStatements) InsertAccount( localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { - // TODO: can we replace "UnixNano() / 1M" with "UnixMilli()"? - createdTimeMS := time.Now().UnixNano() / 1000000 + createdTimeMS := spec.AsTimestamp(time.Now()) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error diff --git a/userapi/storage/postgres/localpart_external_ids_table.go b/userapi/storage/postgres/localpart_external_ids_table.go index 9bc47dbf6..e0f0d34af 100644 --- a/userapi/storage/postgres/localpart_external_ids_table.go +++ b/userapi/storage/postgres/localpart_external_ids_table.go @@ -38,7 +38,7 @@ const selectUserExternalIDSQL = "" + "SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" const deleteUserExternalIDSQL = "" + - "SELECT localpart, external_id, auth_provider, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + "DELETE FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" type localpartExternalIDStatements struct { db *sql.DB @@ -69,7 +69,7 @@ func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, AuthProvider: authProvider, } err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan( - &ret.Localpart, &ret.CreatedTs, + &ret.Localpart, &ret.CreatedTS, ) if err != nil { if err == sql.ErrNoRows { diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 7c6279196..1090ec3ed 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -116,7 +116,7 @@ func (s *accountsStatements) InsertAccount( ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { - createdTimeMS := time.Now().UnixNano() / 1000000 + createdTimeMS := spec.AsTimestamp(time.Now()) stmt := s.insertAccountStmt var err error diff --git a/userapi/storage/sqlite3/localpart_external_ids_table.go b/userapi/storage/sqlite3/localpart_external_ids_table.go index 30f1fc60e..43ae2619c 100644 --- a/userapi/storage/sqlite3/localpart_external_ids_table.go +++ b/userapi/storage/sqlite3/localpart_external_ids_table.go @@ -38,7 +38,7 @@ const selectLocalpartExternalIDSQL = "" + "SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" const deleteLocalpartExternalIDSQL = "" + - "SELECT localpart, external_id, auth_provider, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" + "DELETE FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2" type localpartExternalIDStatements struct { db *sql.DB @@ -69,7 +69,7 @@ func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, AuthProvider: authProvider, } err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan( - &ret.Localpart, &ret.CreatedTs, + &ret.Localpart, &ret.CreatedTS, ) if err != nil { if err == sql.ErrNoRows { From bf31c4429876a8062fa25c534ba10bd1301a8a78 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Thu, 23 Jan 2025 01:58:49 +0000 Subject: [PATCH 55/71] more fixes --- clientapi/auth/default_user_verifier.go | 7 +++--- setup/mscs/msc3861/msc3861.go | 4 ++- setup/mscs/msc3861/msc3861_user_verifier.go | 28 +++++++++++++++------ 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/clientapi/auth/default_user_verifier.go b/clientapi/auth/default_user_verifier.go index 6e6746d11..54147b772 100644 --- a/clientapi/auth/default_user_verifier.go +++ b/clientapi/auth/default_user_verifier.go @@ -20,7 +20,8 @@ type DefaultUserVerifier struct { // Note: For an AS user, AS dummy device is returned. // On failure returns an JSON error response which can be sent to the client. func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) { - util.GetLogger(req.Context()).Debug("Default VerifyUserFromRequest") + ctx := req.Context() + util.GetLogger(ctx).Debug("Default VerifyUserFromRequest") // Try to find the Application Service user token, err := ExtractAccessToken(req) if err != nil { @@ -30,12 +31,12 @@ func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Dev } } var res api.QueryAccessTokenResponse - err = d.UserAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{ + err = d.UserAPI.QueryAccessToken(ctx, &api.QueryAccessTokenRequest{ AccessToken: token, AppServiceUserID: req.URL.Query().Get("user_id"), }, &res) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") + util.GetLogger(ctx).WithError(err).Error("userAPI.QueryAccessToken failed") return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, diff --git a/setup/mscs/msc3861/msc3861.go b/setup/mscs/msc3861/msc3861.go index b3c458b1b..8a0df647d 100644 --- a/setup/mscs/msc3861/msc3861.go +++ b/setup/mscs/msc3861/msc3861.go @@ -2,13 +2,15 @@ package msc3861 import ( "github.com/element-hq/dendrite/setup" + "github.com/matrix-org/gomatrixserverlib/fclient" ) func Enable(m *setup.Monolith) error { + client := fclient.NewClient() userVerifier, err := newMSC3861UserVerifier( m.UserAPI, m.Config.Global.ServerName, m.Config.MSCs.MSC3861, !m.Config.ClientAPI.GuestsDisabled, - nil, + client, ) if err != nil { return err diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index fbfdcaa24..65cf59575 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -1,7 +1,6 @@ package msc3861 import ( - "cmp" "context" "database/sql" "encoding/json" @@ -15,6 +14,7 @@ import ( "github.com/element-hq/dendrite/clientapi/auth" "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -45,7 +45,7 @@ type MSC3861UserVerifier struct { userAPI api.UserInternalAPI serverName spec.ServerName cfg *config.MSC3861 - httpClient *http.Client + httpClient *fclient.Client openIdConfig *OpenIDConfiguration allowGuest bool } @@ -55,13 +55,21 @@ func newMSC3861UserVerifier( serverName spec.ServerName, cfg *config.MSC3861, allowGuest bool, - httpClient *http.Client, + client *fclient.Client, ) (*MSC3861UserVerifier, error) { - client := cmp.Or(httpClient, http.DefaultClient) + if cfg == nil { + return nil, errors.New("unable to create MSC3861UserVerifier object as 'cfg' param is nil") + } + + if client == nil { + return nil, errors.New("unable to create MSC3861UserVerifier object as 'client' param is nil") + } + openIdConfig, err := fetchOpenIDConfiguration(client, cfg.Issuer) if err != nil { return nil, err } + return &MSC3861UserVerifier{ userAPI: userAPI, serverName: serverName, @@ -342,14 +350,14 @@ func (m *MSC3861UserVerifier) introspectToken(ctx context.Context, token string) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") req.SetBasicAuth(m.cfg.ClientID, m.cfg.ClientSecret) - resp, err := m.httpClient.Do(req) + resp, err := m.httpClient.DoHTTPRequest(ctx, req) if err != nil { return nil, err } body := resp.Body defer resp.Body.Close() // nolint: errcheck - if c := resp.StatusCode; c < 200 || c >= 300 { + if c := resp.StatusCode; c/100 != 2 { return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status, "' response"}, "")) } var ir introspectionResponse @@ -394,13 +402,17 @@ type OpenIDConfiguration struct { AccountManagementActionsSupported []string `json:"account_management_actions_supported"` } -func fetchOpenIDConfiguration(httpClient *http.Client, authHostURL string) (*OpenIDConfiguration, error) { +func fetchOpenIDConfiguration(httpClient *fclient.Client, authHostURL string) (*OpenIDConfiguration, error) { u, err := url.Parse(authHostURL) if err != nil { return nil, err } u = u.JoinPath(".well-known/openid-configuration") - resp, err := httpClient.Get(u.String()) + req, err := http.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + resp, err := httpClient.DoHTTPRequest(context.Background(), req) if err != nil { return nil, err } From a185027fda66be6276b49aff61393e915969fcc2 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Thu, 23 Jan 2025 02:45:56 +0000 Subject: [PATCH 56/71] cr fixes --- clientapi/routing/admin.go | 15 ++++----------- internal/httputil/httpapi.go | 3 +++ setup/mscs/msc3861/msc3861_user_verifier.go | 3 +-- syncapi/syncapi_test.go | 20 ++++++++++---------- 4 files changed, 18 insertions(+), 23 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index a7bd4886b..f92a74dba 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -551,12 +551,6 @@ func AdminUserDeviceRetrieveCreate( switch req.Method { case http.MethodPost: - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.InvalidParam(userID), - } - } var payload struct { DeviceID string `json:"device_id"` } @@ -980,11 +974,10 @@ func AdminRetrieveAccount(req *http.Request, cfg *config.ClientAPI, userAPI user Code: http.StatusNotFound, JSON: spec.NotFound(err.Error()), } - } else if err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.Unknown(err.Error()), - } + } + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(err.Error()), } } body.AvatarURL = profile.AvatarURL diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index f04c2bd4f..65a2db2e0 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -58,6 +58,9 @@ func WithAuth() AuthAPIOption { } } +// UserVerifier verifies users by their access tokens. Currently, there are two interface implementations: +// DefaultUserVerifier and MSC3861UserVerifier. The first one checks if the token exists in the server's database, +// whereas the latter passes the token for verification to MAS and acts in accordance with MAS's response. type UserVerifier interface { // VerifyUserFromRequest authenticates the HTTP request, // on success returns Device of the requester. diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index 65cf59575..1e203f278 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -354,14 +354,13 @@ func (m *MSC3861UserVerifier) introspectToken(ctx context.Context, token string) if err != nil { return nil, err } - body := resp.Body defer resp.Body.Close() // nolint: errcheck if c := resp.StatusCode; c/100 != 2 { return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status, "' response"}, "")) } var ir introspectionResponse - if err := json.NewDecoder(body).Decode(&ir); err != nil { + if err := json.NewDecoder(resp.Body).Decode(&ir); err != nil { return nil, err } return &ir, nil diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index efd283826..f6c0c898a 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -120,14 +120,14 @@ func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe return nil } -type userVerifier struct { +type mockUserVerifier struct { accessTokenToDeviceAndResponse map[string]struct { Device *userapi.Device Response *util.JSONResponse } } -func (u *userVerifier) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) { +func (u *mockUserVerifier) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) { if pair, ok := u.accessTokenToDeviceAndResponse[req.URL.Query().Get("access_token")]; ok { return pair.Device, pair.Response } @@ -161,7 +161,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) msgs := toNATSMsgs(t, cfg, room.Events()...) - uv := &userVerifier{} + uv := &mockUserVerifier{} AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, uv, caching.DisableMetrics) testrig.MustPublishMsgs(t, jsctx, msgs...) @@ -284,7 +284,7 @@ func testSyncEventFormatPowerLevels(t *testing.T, dbType test.DBType) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} - uv := userVerifier{ + uv := mockUserVerifier{ accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse @@ -538,7 +538,7 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) - uv := userVerifier{ + uv := mockUserVerifier{ accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse @@ -668,7 +668,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { // Use the actual internal roomserver API rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - uv := userVerifier{ + uv := mockUserVerifier{ accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse @@ -946,7 +946,7 @@ func TestGetMembership(t *testing.T) { // Use an actual roomserver for this rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - uv := userVerifier{ + uv := mockUserVerifier{ accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse @@ -1023,7 +1023,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) defer close() natsInstance := jetstream.NATSInstance{} - uv := userVerifier{ + uv := mockUserVerifier{ accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse @@ -1257,7 +1257,7 @@ func testContext(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - uv := userVerifier{ + uv := mockUserVerifier{ accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse @@ -1445,7 +1445,7 @@ func TestRemoveEditedEventFromSearchIndex(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - uv := userVerifier{ + uv := mockUserVerifier{ accessTokenToDeviceAndResponse: map[string]struct { Device *userapi.Device Response *util.JSONResponse From b5f34dfe47db6825867adb87269bf88022ed23d0 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Thu, 23 Jan 2025 03:10:56 +0000 Subject: [PATCH 57/71] fix test --- setup/mscs/msc3861/msc3861_user_verifier_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup/mscs/msc3861/msc3861_user_verifier_test.go b/setup/mscs/msc3861/msc3861_user_verifier_test.go index 0bfe68e7f..95b4d0e9d 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier_test.go +++ b/setup/mscs/msc3861/msc3861_user_verifier_test.go @@ -97,9 +97,9 @@ func TestVerifyUserFromRequest(t *testing.T) { return &resp, nil } - httpClient := http.Client{ - Transport: &roundTripper{roundTrip: roundTrip}, - } + httpClient := fclient.NewClient( + fclient.WithTransport(&roundTripper{roundTrip: roundTrip}), + ) ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -125,7 +125,7 @@ func TestVerifyUserFromRequest(t *testing.T) { cfg.Global.ServerName, cfg.MSCs.MSC3861, false, - &httpClient, + httpClient, ) if err != nil { t.Fatal(err.Error()) From 453445695c2cfad2cd674605b7d520c4058cb501 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Fri, 24 Jan 2025 23:22:02 +0000 Subject: [PATCH 58/71] mas: store crossSigngingKeysReplacement period in sessionsDict struct instead of db --- clientapi/routing/admin.go | 25 +----------- clientapi/routing/key_crosssigning.go | 18 ++++----- clientapi/routing/register.go | 57 ++++++++++++++++++++++++--- 3 files changed, 61 insertions(+), 39 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index f92a74dba..799ece779 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -35,10 +35,6 @@ import ( "github.com/element-hq/dendrite/userapi/storage/shared" ) -const ( - replacementPeriod time.Duration = 10 * time.Minute -) - var ( validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") deviceDisplayName = "OIDC-native client" @@ -807,27 +803,10 @@ func AdminAllowCrossSigningReplacementWithoutUIA( switch req.Method { case http.MethodPost: - rq := userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest{ - UserID: userID.String(), - Duration: replacementPeriod, - } - var rs userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse - err = userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA(req.Context(), &rq, &rs) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - util.GetLogger(req.Context()).WithError(err).Error("userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.Unknown(err.Error()), - } - } else if errors.Is(err, sql.ErrNoRows) { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound("User not found."), - } - } + ts := sessions.allowCrossSigningKeysReplacement(userID.String()) return util.JSONResponse{ Code: http.StatusOK, - JSON: map[string]int64{"updatable_without_uia_before_ms": rs.Timestamp}, + JSON: map[string]int64{"updatable_without_uia_before_ms": ts}, } default: return util.JSONResponse{ diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index f287e31f3..a0f7f06e1 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -49,11 +49,6 @@ func UploadCrossSigningDeviceKeys( return *resErr } - sessionID := uploadReq.Auth.Session - if sessionID == "" { - sessionID = util.RandomString(sessionIDLength) - } - // Query existing keys to determine if UIA is required keyResp := api.QueryKeysResponse{} keyserverAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ @@ -68,7 +63,6 @@ func UploadCrossSigningDeviceKeys( } existingMasterKey, hasMasterKey := keyResp.MasterKeys[device.UserID] - requireUIA := true if hasMasterKey { if !keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) { @@ -89,10 +83,8 @@ func UploadCrossSigningDeviceKeys( logger.WithError(masterKeyResp.Error).Error("Failed to query master key") return convertKeyError(masterKeyResp.Error) } - if k := masterKeyResp.Key; k != nil && k.UpdatableWithoutUIABeforeMs != nil { - requireUIA = !(time.Now().UnixMilli() < *k.UpdatableWithoutUIABeforeMs) - } + requireUIA := !sessions.isCrossSigningKeysReplacementAllowed(device.UserID) && masterKeyResp.Key != nil if requireUIA { url := "" if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" { @@ -122,9 +114,13 @@ func UploadCrossSigningDeviceKeys( ), } } - // XXX: is it necessary? - sessions.addCompletedSessionStage(sessionID, CrossSigningResetStage) + sessions.restrictCrossSigningKeysReplacement(device.UserID) } else { + sessionID := uploadReq.Auth.Session + if sessionID == "" { + sessionID = util.RandomString(sessionIDLength) + } + if uploadReq.Auth.Type != authtypes.LoginTypePassword { return util.JSONResponse{ Code: http.StatusUnauthorized, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 7bcda2069..74140d4ca 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -66,11 +66,17 @@ type sessionsDict struct { // If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2, // the delete request will fail for device2 since the UIA was initiated by trying to delete device1. deleteSessionToDeviceID map[string]string + // allowedForCrossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating + // cross-signing keys without UIA. + allowedForCrossSigningKeysReplacement map[string]*time.Timer } // defaultTimeout is the timeout used to clean up sessions const defaultTimeOut = time.Minute * 5 +// allowedForCrossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA +const allowedForCrossSigningKeysReplacementDuration = time.Minute * 10 + // getCompletedStages returns the completed stages for a session. func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType { d.RLock() @@ -119,13 +125,54 @@ func (d *sessionsDict) deleteSession(sessionID string) { } } +func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 { + d.Lock() + defer d.Unlock() + ts := time.Now().Add(allowedForCrossSigningKeysReplacementDuration).UnixMilli() + t, ok := d.allowedForCrossSigningKeysReplacement[userID] + if ok { + t.Reset(allowedForCrossSigningKeysReplacementDuration) + return ts + } + d.allowedForCrossSigningKeysReplacement[userID] = time.AfterFunc( + allowedForCrossSigningKeysReplacementDuration, + func() { + d.restrictCrossSigningKeysReplacement(userID) + }, + ) + return ts +} + +func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool { + d.RLock() + defer d.RUnlock() + _, ok := d.allowedForCrossSigningKeysReplacement[userID] + return ok +} + +func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) { + d.Lock() + defer d.Unlock() + t, ok := d.allowedForCrossSigningKeysReplacement[userID] + if ok { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + delete(d.allowedForCrossSigningKeysReplacement, userID) + } +} + func newSessionsDict() *sessionsDict { return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), - sessionCompletedResult: make(map[string]registerResponse), - params: make(map[string]registerRequest), - timer: make(map[string]*time.Timer), - deleteSessionToDeviceID: make(map[string]string), + sessions: make(map[string][]authtypes.LoginType), + sessionCompletedResult: make(map[string]registerResponse), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + deleteSessionToDeviceID: make(map[string]string), + allowedForCrossSigningKeysReplacement: make(map[string]*time.Timer), } } From 0b4cf3baddfc5de34edbf2d678db26f6db115ad7 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sat, 25 Jan 2025 00:38:24 +0000 Subject: [PATCH 59/71] mas: revert cross_signing_keys.updatable_without_uia_before_ms field and related logic --- userapi/api/api.go | 16 +---- userapi/internal/cross_signing.go | 26 ++------ userapi/internal/key_api.go | 2 +- userapi/storage/interface.go | 4 +- .../postgres/cross_signing_keys_table.go | 63 ++++-------------- ...signing_updatable_without_uia_before_ms.go | 23 ------- userapi/storage/shared/storage.go | 20 ++---- .../sqlite3/cross_signing_keys_table.go | 66 ++++--------------- ...signing_updatable_without_uia_before_ms.go | 23 ------- userapi/storage/tables/interface.go | 3 +- userapi/types/storage.go | 7 +- 11 files changed, 41 insertions(+), 212 deletions(-) delete mode 100644 userapi/storage/postgres/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go delete mode 100644 userapi/storage/sqlite3/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go diff --git a/userapi/api/api.go b/userapi/api/api.go index 334b6a1ed..9b1319986 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -689,11 +689,6 @@ type ClientKeyAPI interface { QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QueryMasterKeys(ctx context.Context, req *QueryMasterKeysRequest, res *QueryMasterKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA( - ctx context.Context, - req *PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest, - res *PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse, - ) error PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) // PerformClaimKeys claims one-time keys for use in pre-key messages @@ -922,15 +917,6 @@ type PerformUploadDeviceKeysResponse struct { Error *KeyError } -type PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest struct { - UserID string - Duration time.Duration -} - -type PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse struct { - Timestamp int64 -} - type PerformUploadDeviceSignaturesRequest struct { Signatures map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice // The user that uploaded the sig, should be populated by the clientapi. @@ -968,7 +954,7 @@ type QueryMasterKeysRequest struct { } type QueryMasterKeysResponse struct { - Key *types.CrossSigningKey + Key spec.Base64Bytes // Set if there was a fatal error processing this query Error *KeyError } diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go index b638f59c0..7b245a691 100644 --- a/userapi/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -96,16 +96,6 @@ func sanityCheckKey(key fclient.CrossSigningKey, userID string, purpose fclient. return nil } -func (a *UserInternalAPI) PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA( - ctx context.Context, - req *api.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest, - res *api.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse, -) error { - var err error - res.Timestamp, err = a.KeyDatabase.UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx, req.UserID, req.Duration) - return err -} - // nolint:gocyclo func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { // Find the keys to store. @@ -124,9 +114,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. byPurpose[fclient.CrossSigningKeyPurposeMaster] = req.MasterKey for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey - toStore[fclient.CrossSigningKeyPurposeMaster] = types.CrossSigningKey{ - KeyData: key, - } + toStore[fclient.CrossSigningKeyPurposeMaster] = key } hasMasterKey = true } @@ -142,9 +130,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. byPurpose[fclient.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey - toStore[fclient.CrossSigningKeyPurposeSelfSigning] = types.CrossSigningKey{ - KeyData: key, - } + toStore[fclient.CrossSigningKeyPurposeSelfSigning] = key } } @@ -159,9 +145,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. byPurpose[fclient.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey - toStore[fclient.CrossSigningKeyPurposeUserSigning] = types.CrossSigningKey{ - KeyData: key, - } + toStore[fclient.CrossSigningKeyPurposeUserSigning] = key } } @@ -214,7 +198,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. changed = true break } - if !bytes.Equal(old.KeyData, new.KeyData) { + if !bytes.Equal(old, new) { // One of the existing keys for a purpose we already knew about has // changed. changed = true @@ -226,7 +210,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. } // Store the keys. - if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore, nil); err != nil { + if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), } diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 7ffaf10d7..eb7597ab9 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -243,7 +243,7 @@ func (a *UserInternalAPI) QueryMasterKeys(ctx context.Context, req *api.QueryMas return } if key, ok := crossSigningKeyMap[fclient.CrossSigningKeyPurposeMaster]; ok { - res.Key = &key + res.Key = key } } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 2c0b4bf2a..3cf7e7659 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -10,7 +10,6 @@ import ( "context" "encoding/json" "errors" - "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" @@ -231,9 +230,8 @@ type KeyDatabase interface { CrossSigningKeysDataForUserAndKeyType(ctx context.Context, userID string, keyType fclient.CrossSigningKeyPurpose) (types.CrossSigningKeyMap, error) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) - StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap, updatableWithoutUIABeforeMs *int64) error + StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature spec.Base64Bytes) error - UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, userID string, duration time.Duration) (int64, error) DeleteStaleDeviceLists( ctx context.Context, diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index 61b15b884..38e6a3e26 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -10,9 +10,6 @@ import ( "context" "database/sql" "fmt" - "time" - - "github.com/element-hq/dendrite/userapi/storage/postgres/deltas" "github.com/element-hq/dendrite/internal" "github.com/element-hq/dendrite/internal/sqlutil" @@ -32,29 +29,23 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( ` const selectCrossSigningKeysForUserSQL = "" + - "SELECT key_type, key_data, updatable_without_uia_before_ms FROM keyserver_cross_signing_keys" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + " WHERE user_id = $1" const selectCrossSigningKeysForUserAndKeyTypeSQL = "" + - "SELECT key_type, key_data, updatable_without_uia_before_ms FROM keyserver_cross_signing_keys" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + " WHERE user_id = $1 AND key_type = $2" const upsertCrossSigningKeysForUserSQL = "" + - "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data, updatable_without_uia_before_ms)" + + "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + " VALUES($1, $2, $3, $4)" + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" -const updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL = "" + - "UPDATE keyserver_cross_signing_keys" + - " SET updatable_without_uia_before_ms = $1" + - " WHERE user_id = $2 AND key_type = $3" - type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt - updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt *sql.Stmt + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt } func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -66,12 +57,6 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigrations( - sqlutil.Migration{ - Version: "userapi: add x-signing updatable_without_uia_before_ms", - Up: deltas.UpAddXSigningUpdatableWithoutUIABeforeMs, - }, - ) err = m.Up(context.Background()) if err != nil { return nil, err @@ -80,7 +65,6 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, - {&s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt, updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL}, }.Prepare(db) } @@ -96,18 +80,14 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( for rows.Next() { var keyTypeInt int16 var keyData spec.Base64Bytes - var updatableWithoutUIABeforeMs *int64 - if err = rows.Scan(&keyTypeInt, &keyData, &updatableWithoutUIABeforeMs); err != nil { + if err = rows.Scan(&keyTypeInt, &keyData); err != nil { return nil, err } keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] if !ok { return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) } - r[keyType] = types.CrossSigningKey{ - UpdatableWithoutUIABeforeMs: updatableWithoutUIABeforeMs, - KeyData: keyData, - } + r[keyType] = keyData } err = rows.Err() return @@ -129,45 +109,28 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUserAndKeyType( for rows.Next() { var keyTypeInt int16 var keyData spec.Base64Bytes - var updatableWithoutUIABeforeMs *int64 - if err = rows.Scan(&keyTypeInt, &keyData, &updatableWithoutUIABeforeMs); err != nil { + if err = rows.Scan(&keyTypeInt, &keyData); err != nil { return nil, err } keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] if !ok { return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) } - r[keyType] = types.CrossSigningKey{ - UpdatableWithoutUIABeforeMs: updatableWithoutUIABeforeMs, - KeyData: keyData, - } + r[keyType] = keyData } err = rows.Err() return } func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, updatableWithoutUIABeforeMs *int64, + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, ) error { keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] if !ok { return fmt.Errorf("unknown key purpose %q", keyType) } - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData, updatableWithoutUIABeforeMs); err != nil { + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) } return nil } - -func (s *crossSigningKeysStatements) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, txn *sql.Tx, userID string, duration time.Duration) (int64, error) { - keyTypeInt := types.KeyTypePurposeToInt[fclient.CrossSigningKeyPurposeMaster] - ts := time.Now().Add(duration).UnixMilli() - result, err := sqlutil.TxStmt(txn, s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt).ExecContext(ctx, ts, userID, keyTypeInt) - if err != nil { - return -1, err - } - if n, _ := result.RowsAffected(); n == 0 { - return -1, sql.ErrNoRows - } - return ts, nil -} diff --git a/userapi/storage/postgres/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go b/userapi/storage/postgres/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go deleted file mode 100644 index 11f32aecb..000000000 --- a/userapi/storage/postgres/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go +++ /dev/null @@ -1,23 +0,0 @@ -package deltas - -import ( - "context" - "database/sql" - "fmt" -) - -func UpAddXSigningUpdatableWithoutUIABeforeMs(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, `ALTER TABLE keyserver_cross_signing_keys ADD COLUMN IF NOT EXISTS updatable_without_uia_before_ms BIGINT DEFAULT NULL;`) - if err != nil { - return fmt.Errorf("failed to execute upgrade: %w", err) - } - return nil -} - -func DownAddXSigningUpdatableWithoutUIABeforeMs(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, `ALTER TABLE keyserver_cross_signing_keys DROP COLUMN IF EXISTS updatable_without_uia_before_ms;`) - if err != nil { - return fmt.Errorf("failed to execute downgrade: %w", err) - } - return nil -} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 3d6d51edc..834c76488 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -1136,12 +1136,12 @@ func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string } results := map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey{} for purpose, key := range keyMap { - keyID := gomatrixserverlib.KeyID("ed25519:" + key.KeyData.Encode()) + keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) result := fclient.CrossSigningKey{ UserID: userID, Usage: []fclient.CrossSigningKeyPurpose{purpose}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{ - keyID: key.KeyData, + keyID: key, }, } sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) @@ -1183,10 +1183,10 @@ func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserI } // StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. -func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap, updatableWithoutUIABeforeMs *int64) error { +func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for keyType, key := range keyMap { - if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, key.KeyData, key.UpdatableWithoutUIABeforeMs); err != nil { + if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, key); err != nil { return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) } } @@ -1194,18 +1194,6 @@ func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID s }) } -// UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA updates the 'updatable_without_uia_before_ms' attribute of the master cross-signing key. -// Normally this attribute depending on its value marks the master key as replaceable without UIA. -func (d *KeyDatabase) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, userID string, duration time.Duration) (int64, error) { - var ts int64 - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - var err error - ts, err = d.CrossSigningKeysTable.UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx, txn, userID, duration) - return err - }) - return ts, err -} - // StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/device. func (d *KeyDatabase) StoreCrossSigningSigsForTarget( ctx context.Context, diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index 19f9b6c19..c57ffd398 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -10,9 +10,6 @@ import ( "context" "database/sql" "fmt" - "time" - - "github.com/element-hq/dendrite/userapi/storage/sqlite3/deltas" "github.com/element-hq/dendrite/internal" "github.com/element-hq/dendrite/internal/sqlutil" @@ -32,28 +29,22 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( ` const selectCrossSigningKeysForUserSQL = "" + - "SELECT key_type, key_data, updatable_without_uia_before_ms FROM keyserver_cross_signing_keys" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + " WHERE user_id = $1" const selectCrossSigningKeysForUserAndKeyTypeSQL = "" + - "SELECT key_type, key_data, updatable_without_uia_before_ms FROM keyserver_cross_signing_keys" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + " WHERE user_id = $1 AND key_type = $2" const upsertCrossSigningKeysForUserSQL = "" + - "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data, updatable_without_uia_before_ms)" + - " VALUES($1, $2, $3, $4)" - -const updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL = "" + - "UPDATE keyserver_cross_signing_keys" + - " SET updatable_without_uia_before_ms = $1" + - " WHERE user_id = $2 AND key_type = $3" + "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + + " VALUES($1, $2, $3)" type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt - updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt *sql.Stmt + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt } func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -65,12 +56,6 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) return nil, err } m := sqlutil.NewMigrator(db) - m.AddMigrations( - sqlutil.Migration{ - Version: "userapi: add x-signing updatable_without_uia_before_ms", - Up: deltas.UpAddXSigningUpdatableWithoutUIABeforeMs, - }, - ) err = m.Up(context.Background()) if err != nil { return nil, err @@ -79,7 +64,6 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, - {&s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt, updateMasterCrossSigningKeyAllowReplacementWithoutUiaSQL}, }.Prepare(db) } @@ -95,18 +79,14 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( for rows.Next() { var keyTypeInt int16 var keyData spec.Base64Bytes - var updatableWithoutUiaBeforeMs *int64 - if err = rows.Scan(&keyTypeInt, &keyData, &updatableWithoutUiaBeforeMs); err != nil { + if err = rows.Scan(&keyTypeInt, &keyData); err != nil { return nil, err } keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] if !ok { return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) } - r[keyType] = types.CrossSigningKey{ - UpdatableWithoutUIABeforeMs: updatableWithoutUiaBeforeMs, - KeyData: keyData, - } + r[keyType] = keyData } err = rows.Err() return @@ -128,45 +108,27 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUserAndKeyType( for rows.Next() { var keyTypeInt int16 var keyData spec.Base64Bytes - var updatableWithoutUIABeforeMs *int64 - if err = rows.Scan(&keyTypeInt, &keyData, &updatableWithoutUIABeforeMs); err != nil { + if err = rows.Scan(&keyTypeInt, &keyData); err != nil { return nil, err } keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] if !ok { return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) } - r[keyType] = types.CrossSigningKey{ - UpdatableWithoutUIABeforeMs: updatableWithoutUIABeforeMs, - KeyData: keyData, - } + r[keyType] = keyData } err = rows.Err() return } func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, updatableWithoutUIABeforeMs *int64, -) error { + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error { keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] if !ok { return fmt.Errorf("unknown key purpose %q", keyType) } - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData, updatableWithoutUIABeforeMs); err != nil { + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) } return nil } - -func (s *crossSigningKeysStatements) UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, txn *sql.Tx, userID string, duration time.Duration) (int64, error) { - keyTypeInt := types.KeyTypePurposeToInt[fclient.CrossSigningKeyPurposeMaster] - ts := time.Now().Add(duration).UnixMilli() - result, err := sqlutil.TxStmt(txn, s.updateMasterCrossSigningKeyAllowReplacementWithoutUiaStmt).ExecContext(ctx, ts, userID, keyTypeInt) - if err != nil { - return -1, err - } - if n, _ := result.RowsAffected(); n == 0 { - return -1, sql.ErrNoRows - } - return ts, nil -} diff --git a/userapi/storage/sqlite3/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go b/userapi/storage/sqlite3/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go deleted file mode 100644 index 2935509a7..000000000 --- a/userapi/storage/sqlite3/deltas/2025011001110000_add_xsigning_updatable_without_uia_before_ms.go +++ /dev/null @@ -1,23 +0,0 @@ -package deltas - -import ( - "context" - "database/sql" - "fmt" -) - -func UpAddXSigningUpdatableWithoutUIABeforeMs(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, `ALTER TABLE keyserver_cross_signing_keys ADD COLUMN updatable_without_uia_before_ms BIGINT DEFAULT NULL;`) - if err != nil { - return fmt.Errorf("failed to execute upgrade: %w", err) - } - return nil -} - -func DownAddXSigningUpdatableWithoutUIABeforeMs(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, `ALTER TABLE keyserver_cross_signing_keys DROP COLUMN updatable_without_uia_before_ms;`) - if err != nil { - return fmt.Errorf("failed to execute downgrade: %w", err) - } - return nil -} diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 8e629914e..434702761 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -199,8 +199,7 @@ type StaleDeviceLists interface { type CrossSigningKeys interface { SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) SelectCrossSigningKeysForUserAndKeyType(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose) (r types.CrossSigningKeyMap, err error) - UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, updatableWithoutUIABeforeMs *int64) error - UpdateMasterCrossSigningKeyAllowReplacementWithoutUIA(ctx context.Context, txn *sql.Tx, userID string, duration time.Duration) (int64, error) + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error } type CrossSigningSigs interface { diff --git a/userapi/types/storage.go b/userapi/types/storage.go index 3e58c8e67..971f3dc9a 100644 --- a/userapi/types/storage.go +++ b/userapi/types/storage.go @@ -37,13 +37,8 @@ var KeyTypeIntToPurpose = map[int16]fclient.CrossSigningKeyPurpose{ 3: fclient.CrossSigningKeyPurposeUserSigning, } -type CrossSigningKey struct { - UpdatableWithoutUIABeforeMs *int64 - KeyData spec.Base64Bytes -} - // Map of purpose -> public key -type CrossSigningKeyMap map[fclient.CrossSigningKeyPurpose]CrossSigningKey +type CrossSigningKeyMap map[fclient.CrossSigningKeyPurpose]spec.Base64Bytes // Map of user ID -> key ID -> signature type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes From 27f7a5e3ebd2652814edf4af85fcc14a34ff05a2 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sat, 25 Jan 2025 00:56:19 +0000 Subject: [PATCH 60/71] mas: fix tests --- clientapi/admin_test.go | 14 ++++---------- clientapi/routing/admin.go | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 2d6cc1086..ca97b2f46 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -2021,16 +2021,10 @@ func TestAdminAllowCrossSigningReplacementWithoutUIA(t *testing.T) { t.Errorf("failed to create account: %s", err) } _ = userAPI.KeyDatabase.StoreCrossSigningKeysForUser(ctx, alice.ID, types.CrossSigningKeyMap{ - fclient.CrossSigningKeyPurposeMaster: types.CrossSigningKey{ - KeyData: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), - }, - fclient.CrossSigningKeyPurposeSelfSigning: types.CrossSigningKey{ - KeyData: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), - }, - fclient.CrossSigningKeyPurposeUserSigning: types.CrossSigningKey{ - KeyData: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), - }, - }, nil) + fclient.CrossSigningKeyPurposeMaster: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), + fclient.CrossSigningKeyPurposeSelfSigning: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), + fclient.CrossSigningKeyPurposeUserSigning: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"), + }) } diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 799ece779..a34a5765a 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -801,6 +801,23 @@ func AdminAllowCrossSigningReplacementWithoutUIA( } } + var rs api.QueryAccountByLocalpartResponse + err = userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{ + Localpart: userID.Local(), + ServerName: userID.Domain(), + }, &rs) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountByLocalpart") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(err.Error()), + } + } else if errors.Is(err, sql.ErrNoRows) { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("User not found."), + } + } switch req.Method { case http.MethodPost: ts := sessions.allowCrossSigningKeysReplacement(userID.String()) From c1ad175178a6e4921ef0ba0e4d969d5060fe990e Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sat, 25 Jan 2025 01:05:13 +0000 Subject: [PATCH 61/71] more test fixes --- clientapi/routing/key_crosssigning_test.go | 10 ++-------- userapi/storage/postgres/cross_signing_keys_table.go | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/clientapi/routing/key_crosssigning_test.go b/clientapi/routing/key_crosssigning_test.go index 2d0893626..0db15ab92 100644 --- a/clientapi/routing/key_crosssigning_test.go +++ b/clientapi/routing/key_crosssigning_test.go @@ -10,8 +10,6 @@ import ( "strings" "testing" - "github.com/element-hq/dendrite/userapi/types" - "github.com/element-hq/dendrite/setup/config" "github.com/element-hq/dendrite/test" "github.com/element-hq/dendrite/test/testrig" @@ -134,9 +132,7 @@ func Test_UploadCrossSigningDeviceKeys_Unauthorised(t *testing.T) { }, queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ "@user:example.com": { - Key: &types.CrossSigningKey{ - KeyData: spec.Base64Bytes("key1"), - }, + Key: spec.Base64Bytes("key1"), }, }, } @@ -200,9 +196,7 @@ func Test_UploadCrossSigningDeviceKeys_ExistingKeysMismatch(t *testing.T) { }, queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ "@user:example.com": { - Key: &types.CrossSigningKey{ - KeyData: spec.Base64Bytes("different_key"), - }, + Key: spec.Base64Bytes("different_key"), }, }, } diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index 38e6a3e26..f05f7845a 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -38,7 +38,7 @@ const selectCrossSigningKeysForUserAndKeyTypeSQL = "" + const upsertCrossSigningKeysForUserSQL = "" + "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + - " VALUES($1, $2, $3, $4)" + + " VALUES($1, $2, $3)" + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" type crossSigningKeysStatements struct { From b8ea41b2adab75e64aa95116e76c59525474fcca Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Sat, 25 Jan 2025 02:02:24 +0000 Subject: [PATCH 62/71] tests for sessionsDict.crossSigningKeysReplacement --- clientapi/routing/register.go | 36 ++++++++++---------- clientapi/routing/register_test.go | 54 ++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 18 deletions(-) diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 74140d4ca..da43a6b01 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -66,16 +66,16 @@ type sessionsDict struct { // If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2, // the delete request will fail for device2 since the UIA was initiated by trying to delete device1. deleteSessionToDeviceID map[string]string - // allowedForCrossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating + // crossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating // cross-signing keys without UIA. - allowedForCrossSigningKeysReplacement map[string]*time.Timer + crossSigningKeysReplacement map[string]*time.Timer } // defaultTimeout is the timeout used to clean up sessions const defaultTimeOut = time.Minute * 5 -// allowedForCrossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA -const allowedForCrossSigningKeysReplacementDuration = time.Minute * 10 +// crossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA +const crossSigningKeysReplacementDuration = time.Minute * 10 // getCompletedStages returns the completed stages for a session. func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType { @@ -128,14 +128,14 @@ func (d *sessionsDict) deleteSession(sessionID string) { func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 { d.Lock() defer d.Unlock() - ts := time.Now().Add(allowedForCrossSigningKeysReplacementDuration).UnixMilli() - t, ok := d.allowedForCrossSigningKeysReplacement[userID] + ts := time.Now().Add(crossSigningKeysReplacementDuration).UnixMilli() + t, ok := d.crossSigningKeysReplacement[userID] if ok { - t.Reset(allowedForCrossSigningKeysReplacementDuration) + t.Reset(crossSigningKeysReplacementDuration) return ts } - d.allowedForCrossSigningKeysReplacement[userID] = time.AfterFunc( - allowedForCrossSigningKeysReplacementDuration, + d.crossSigningKeysReplacement[userID] = time.AfterFunc( + crossSigningKeysReplacementDuration, func() { d.restrictCrossSigningKeysReplacement(userID) }, @@ -146,14 +146,14 @@ func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 { func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool { d.RLock() defer d.RUnlock() - _, ok := d.allowedForCrossSigningKeysReplacement[userID] + _, ok := d.crossSigningKeysReplacement[userID] return ok } func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) { d.Lock() defer d.Unlock() - t, ok := d.allowedForCrossSigningKeysReplacement[userID] + t, ok := d.crossSigningKeysReplacement[userID] if ok { if !t.Stop() { select { @@ -161,18 +161,18 @@ func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) { default: } } - delete(d.allowedForCrossSigningKeysReplacement, userID) + delete(d.crossSigningKeysReplacement, userID) } } func newSessionsDict() *sessionsDict { return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), - sessionCompletedResult: make(map[string]registerResponse), - params: make(map[string]registerRequest), - timer: make(map[string]*time.Timer), - deleteSessionToDeviceID: make(map[string]string), - allowedForCrossSigningKeysReplacement: make(map[string]*time.Timer), + sessions: make(map[string][]authtypes.LoginType), + sessionCompletedResult: make(map[string]registerResponse), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + deleteSessionToDeviceID: make(map[string]string), + crossSigningKeysReplacement: make(map[string]*time.Timer), } } diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 71cc0ca67..8529d7c59 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -669,3 +669,57 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) { assert.Equal(t, expectedDisplayName, profile.DisplayName) }) } + +func TestCrossSigningKeysReplacement(t *testing.T) { + userID := "@user:example.com" + + t.Run("Can add new session", func(t *testing.T) { + s := newSessionsDict() + assert.Empty(t, s.crossSigningKeysReplacement) + s.allowCrossSigningKeysReplacement(userID) + assert.Len(t, s.crossSigningKeysReplacement, 1) + assert.Contains(t, s.crossSigningKeysReplacement, userID) + }) + + t.Run("Can check if session exists or not", func(t *testing.T) { + s := newSessionsDict() + t.Run("exists", func(t *testing.T) { + s.allowCrossSigningKeysReplacement(userID) + assert.Len(t, s.crossSigningKeysReplacement, 1) + assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID)) + }) + + t.Run("not exists", func(t *testing.T) { + assert.False(t, s.isCrossSigningKeysReplacementAllowed("@random:test.com")) + }) + }) + + t.Run("Can deactivate session", func(t *testing.T) { + s := newSessionsDict() + assert.Empty(t, s.crossSigningKeysReplacement) + t.Run("not exists", func(t *testing.T) { + s.restrictCrossSigningKeysReplacement("@random:test.com") + assert.Empty(t, s.crossSigningKeysReplacement) + }) + + t.Run("exists", func(t *testing.T) { + s.allowCrossSigningKeysReplacement(userID) + s.restrictCrossSigningKeysReplacement(userID) + assert.Empty(t, s.crossSigningKeysReplacement) + }) + }) + + t.Run("Can erase expired sessions", func(t *testing.T) { + s := newSessionsDict() + s.allowCrossSigningKeysReplacement(userID) + assert.Len(t, s.crossSigningKeysReplacement, 1) + assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID)) + timer := s.crossSigningKeysReplacement[userID] + + // pretending the timer is expired + timer.Reset(time.Millisecond) + time.Sleep(time.Millisecond * 500) + + assert.Empty(t, s.crossSigningKeysReplacement) + }) +} From 8df644263cb28d6c7844517ea92e8a409cefa3b8 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 5 Feb 2025 20:56:11 +0000 Subject: [PATCH 63/71] msd3861: added license headers for the new files --- clientapi/admin_test.go | 5 +++++ clientapi/auth/default_user_verifier.go | 5 +++++ clientapi/routing/admin.go | 5 +++++ setup/mscs/msc3861/msc3861.go | 5 +++++ setup/mscs/msc3861/msc3861_user_verifier.go | 5 +++++ setup/mscs/msc3861/msc3861_user_verifier_test.go | 5 +++++ .../deltas/2024123101250000_drop_primary_key_constraint.go | 5 +++++ userapi/storage/postgres/localpart_external_ids_table.go | 5 +++++ .../deltas/2024123101150000_drop_primary_key_constraint.go | 5 +++++ userapi/storage/sqlite3/localpart_external_ids_table.go | 5 +++++ 10 files changed, 50 insertions(+) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index ca97b2f46..02e89649d 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package clientapi import ( diff --git a/clientapi/auth/default_user_verifier.go b/clientapi/auth/default_user_verifier.go index 54147b772..e6ecaf23e 100644 --- a/clientapi/auth/default_user_verifier.go +++ b/clientapi/auth/default_user_verifier.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package auth import ( diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index a34a5765a..7cc03ebeb 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package routing import ( diff --git a/setup/mscs/msc3861/msc3861.go b/setup/mscs/msc3861/msc3861.go index 8a0df647d..9e3d00123 100644 --- a/setup/mscs/msc3861/msc3861.go +++ b/setup/mscs/msc3861/msc3861.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package msc3861 import ( diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index 1e203f278..0de9b7fc9 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package msc3861 import ( diff --git a/setup/mscs/msc3861/msc3861_user_verifier_test.go b/setup/mscs/msc3861/msc3861_user_verifier_test.go index 95b4d0e9d..fd1d22a92 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier_test.go +++ b/setup/mscs/msc3861/msc3861_user_verifier_test.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package msc3861 import ( diff --git a/userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go b/userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go index 0bf7d3863..e88423361 100644 --- a/userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go +++ b/userapi/storage/postgres/deltas/2024123101250000_drop_primary_key_constraint.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package deltas import ( diff --git a/userapi/storage/postgres/localpart_external_ids_table.go b/userapi/storage/postgres/localpart_external_ids_table.go index e0f0d34af..bc2adac20 100644 --- a/userapi/storage/postgres/localpart_external_ids_table.go +++ b/userapi/storage/postgres/localpart_external_ids_table.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package postgres import ( diff --git a/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go b/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go index 3758819bd..d66053530 100644 --- a/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go +++ b/userapi/storage/sqlite3/deltas/2024123101150000_drop_primary_key_constraint.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package deltas import ( diff --git a/userapi/storage/sqlite3/localpart_external_ids_table.go b/userapi/storage/sqlite3/localpart_external_ids_table.go index 43ae2619c..acbd5a7e9 100644 --- a/userapi/storage/sqlite3/localpart_external_ids_table.go +++ b/userapi/storage/sqlite3/localpart_external_ids_table.go @@ -1,3 +1,8 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. + package sqlite3 import ( From 3eb4c7e1bd243ab3e571008835013abb5ea3803f Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 12 Feb 2025 16:13:32 +0000 Subject: [PATCH 64/71] msc3861: cr fixes --- clientapi/admin_test.go | 23 +++++++++++++---------- clientapi/routing/admin.go | 34 ++++++++++++++++------------------ clientapi/routing/register.go | 6 +++--- clientapi/routing/routing.go | 8 ++------ internal/httputil/httpapi.go | 2 +- setup/monolith.go | 2 +- 6 files changed, 36 insertions(+), 39 deletions(-) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 02e89649d..160fe8f36 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1561,16 +1561,19 @@ func TestAdminCheckUsernameAvailable(t *testing.T) { t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) } + // Nothing more to check, test is done. if tc.wantOK { - b := make(map[string]bool, 1) - _ = json.NewDecoder(rec.Body).Decode(&b) - available, ok := b["available"] - if !ok { - t.Fatal("'available' not found in body") - } - if available != tc.isAvailable { - t.Fatalf("expected 'available' to be %t, got %t instead", tc.isAvailable, available) - } + return + } + + b := make(map[string]bool, 1) + _ = json.NewDecoder(rec.Body).Decode(&b) + available, ok := b["available"] + if !ok { + t.Fatal("'available' not found in body") + } + if available != tc.isAvailable { + t.Fatalf("expected 'available' to be %t, got %t instead", tc.isAvailable, available) } }) } @@ -2311,7 +2314,7 @@ func TestAdminRetrieveAccount(t *testing.T) { } for _, tc := range testCase { - t.Run("Retrieve existing account", func(t *testing.T) { + t.Run(tc.Name, func(t *testing.T) { req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+tc.User.ID) req.Header.Set("Authorization", "Bearer "+adminToken) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 7cc03ebeb..79494d8f5 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -560,26 +560,24 @@ func AdminUserDeviceRetrieveCreate( } userDeviceExists := false - { - var rs api.QueryDevicesResponse - if err = userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil { - logger.WithError(err).Error("QueryDevices") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } + var rs api.QueryDevicesResponse + if err = userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil { + logger.WithError(err).Error("QueryDevices") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } - if !rs.UserExists { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound("Given user ID does not exist"), - } + } + if !rs.UserExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("Given user ID does not exist"), } - for i := range rs.Devices { - if d := rs.Devices[i]; d.ID == payload.DeviceID && d.UserID == userID { - userDeviceExists = true - break - } + } + for i := range rs.Devices { + if d := rs.Devices[i]; d.ID == payload.DeviceID && d.UserID == userID { + userDeviceExists = true + break } } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index da43a6b01..8fbebf1b9 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -128,11 +128,11 @@ func (d *sessionsDict) deleteSession(sessionID string) { func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 { d.Lock() defer d.Unlock() - ts := time.Now().Add(crossSigningKeysReplacementDuration).UnixMilli() + allowedUntilTS := time.Now().Add(crossSigningKeysReplacementDuration).UnixMilli() t, ok := d.crossSigningKeysReplacement[userID] if ok { t.Reset(crossSigningKeysReplacementDuration) - return ts + return allowedUntilTS } d.crossSigningKeysReplacement[userID] = time.AfterFunc( crossSigningKeysReplacementDuration, @@ -140,7 +140,7 @@ func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 { d.restrictCrossSigningKeysReplacement(userID) }, ) - return ts + return allowedUntilTS } func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool { diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 15a5addfb..65465c58d 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -349,14 +349,10 @@ func Setup( })).Methods(http.MethodPost) synapseAdminRouter.Handle("/admin/v2/users/{userID}", httputil.MakeServiceAdminAPI("admin_manage_user", m.AdminToken, func(r *http.Request) util.JSONResponse { - switch r.Method { - case http.MethodGet: + if r.Method == http.MethodGet { return AdminRetrieveAccount(r, cfg, userAPI) - case http.MethodPut: - return AdminCreateOrModifyAccount(r, userAPI, cfg) - default: - return util.JSONResponse{Code: http.StatusMethodNotAllowed, JSON: nil} } + return AdminCreateOrModifyAccount(r, userAPI, cfg) })).Methods(http.MethodPut, http.MethodGet) synapseAdminRouter.Handle("/admin/v2/users/{userID}/devices", httputil.MakeServiceAdminAPI("admin_create_retrieve_user_devices", m.AdminToken, func(r *http.Request) util.JSONResponse { diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 65a2db2e0..7eee0d694 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -163,7 +163,7 @@ func MakeServiceAdminAPI( } } if token != serviceToken { - logger.Debugf("Invalid service token '%s'", token) + logger.Debug("Invalid service token") return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.UnknownToken(token), diff --git a/setup/monolith.go b/setup/monolith.go index 8d8fadc90..7dd677051 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -51,7 +51,7 @@ type Monolith struct { ExtPublicRoomsProvider api.ExtraPublicRoomsProvider ExtUserDirectoryProvider userapi.QuerySearchProfilesAPI - UserVerifierProvider *UserVerifierProvider + UserVerifierProvider httputil.UserVerifier } // AddAllPublicRoutes attaches all public paths to the given router From f91cc64401ab84b7d4a01850577fc147f58a9f14 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 12 Feb 2025 16:40:22 +0000 Subject: [PATCH 65/71] msc3861: cr fixes --- setup/mscs/msc3861/msc3861.go | 2 +- .../storage/postgres/localpart_external_ids_table.go | 12 ++++++------ userapi/storage/shared/storage.go | 6 +++--- .../storage/sqlite3/localpart_external_ids_table.go | 12 ++++++------ userapi/storage/tables/interface.go | 6 +++--- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/setup/mscs/msc3861/msc3861.go b/setup/mscs/msc3861/msc3861.go index 9e3d00123..4dd5d517c 100644 --- a/setup/mscs/msc3861/msc3861.go +++ b/setup/mscs/msc3861/msc3861.go @@ -20,6 +20,6 @@ func Enable(m *setup.Monolith) error { if err != nil { return err } - m.UserVerifierProvider.UserVerifier = userVerifier + m.UserVerifierProvider = setup.NewUserVerifierProvider(userVerifier) return nil } diff --git a/userapi/storage/postgres/localpart_external_ids_table.go b/userapi/storage/postgres/localpart_external_ids_table.go index bc2adac20..8e6a81d4b 100644 --- a/userapi/storage/postgres/localpart_external_ids_table.go +++ b/userapi/storage/postgres/localpart_external_ids_table.go @@ -67,8 +67,8 @@ func NewPostgresLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalI }.Prepare(db) } -// Select selects an existing OpenID Connect connection from the database -func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) { +// SelectLocalExternalPartID selects an existing OpenID Connect connection from the database +func (u *localpartExternalIDStatements) SelectLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) { ret := api.LocalpartExternalID{ ExternalID: externalID, AuthProvider: authProvider, @@ -87,15 +87,15 @@ func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, return &ret, nil } -// Insert creates a new record representing an OpenID Connect connection between Matrix and external accounts. -func (u *localpartExternalIDStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error { +// InsertLocalExternalPartID creates a new record representing an OpenID Connect connection between Matrix and external accounts. +func (u *localpartExternalIDStatements) InsertLocalExternalPartID(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error { stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt) _, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix()) return err } -// Delete deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account. -func (u *localpartExternalIDStatements) Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error { +// DeleteLocalExternalPartID deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account. +func (u *localpartExternalIDStatements) DeleteLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error { stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt) _, err := stmt.ExecContext(ctx, externalID, authProvider) return err diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 834c76488..17140e69c 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -907,15 +907,15 @@ func (d *Database) UpsertPusher( } func (d *Database) CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error { - return d.LocalpartExternalIDs.Insert(ctx, nil, localpart, externalID, authProvider) + return d.LocalpartExternalIDs.InsertLocalExternalPartID(ctx, nil, localpart, externalID, authProvider) } func (d *Database) GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) { - return d.LocalpartExternalIDs.Select(ctx, nil, externalID, authProvider) + return d.LocalpartExternalIDs.SelectLocalExternalPartID(ctx, nil, externalID, authProvider) } func (d *Database) DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error { - return d.LocalpartExternalIDs.Delete(ctx, nil, externalID, authProvider) + return d.LocalpartExternalIDs.DeleteLocalExternalPartID(ctx, nil, externalID, authProvider) } // GetPushers returns the pushers matching the given localpart. diff --git a/userapi/storage/sqlite3/localpart_external_ids_table.go b/userapi/storage/sqlite3/localpart_external_ids_table.go index acbd5a7e9..e5074625a 100644 --- a/userapi/storage/sqlite3/localpart_external_ids_table.go +++ b/userapi/storage/sqlite3/localpart_external_ids_table.go @@ -67,8 +67,8 @@ func NewSQLiteLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDs }.Prepare(db) } -// Select selects an existing OpenID Connect connection from the database -func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) { +// SelectLocalExternalPartID selects an existing OpenID Connect connection from the database +func (u *localpartExternalIDStatements) SelectLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) { ret := api.LocalpartExternalID{ ExternalID: externalID, AuthProvider: authProvider, @@ -87,15 +87,15 @@ func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx, return &ret, nil } -// Insert creates a new record representing an OpenID Connect connection between Matrix and external accounts. -func (u *localpartExternalIDStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error { +// InsertLocalExternalPartID creates a new record representing an OpenID Connect connection between Matrix and external accounts. +func (u *localpartExternalIDStatements) InsertLocalExternalPartID(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error { stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt) _, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix()) return err } -// Delete deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account. -func (u *localpartExternalIDStatements) Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error { +// DeleteLocalExternalPartID deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account. +func (u *localpartExternalIDStatements) DeleteLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error { stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt) _, err := stmt.ExecContext(ctx, externalID, authProvider) return err diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 434702761..fd9f10c30 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -128,9 +128,9 @@ type StatsTable interface { } type LocalpartExternalIDsTable interface { - Select(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) - Insert(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error - Delete(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error + SelectLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) + InsertLocalExternalPartID(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error + DeleteLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error } type NotificationFilter uint32 From fd52c7eb1f0e62f0db0ea198a2a6c4ac5674429b Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 12 Feb 2025 17:45:31 +0000 Subject: [PATCH 66/71] msc3861: cr --- setup/monolith.go | 4 ++-- setup/mscs/msc3861/msc3861_user_verifier.go | 23 +++++++-------------- userapi/api/api.go | 20 ++---------------- userapi/internal/user_api.go | 9 ++++---- 4 files changed, 15 insertions(+), 41 deletions(-) diff --git a/setup/monolith.go b/setup/monolith.go index 7dd677051..b61633c1a 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -89,11 +89,11 @@ func (m *Monolith) AddAllPublicRoutes( } type UserVerifierProvider struct { - UserVerifier httputil.UserVerifier + httputil.UserVerifier } func (u *UserVerifierProvider) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) { - return u.UserVerifier.VerifyUserFromRequest(req) + return u.VerifyUserFromRequest(req) } func NewUserVerifierProvider(userVerifier httputil.UserVerifier) *UserVerifierProvider { diff --git a/setup/mscs/msc3861/msc3861_user_verifier.go b/setup/mscs/msc3861/msc3861_user_verifier.go index 0de9b7fc9..30f578dcb 100644 --- a/setup/mscs/msc3861/msc3861_user_verifier.go +++ b/setup/mscs/msc3861/msc3861_user_verifier.go @@ -202,17 +202,12 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st } localpart := "" - { - var rs api.QueryLocalpartExternalIDResponse - if err = m.userAPI.QueryExternalUserIDByLocalpartAndProvider(ctx, &api.QueryLocalpartExternalIDRequest{ - ExternalID: sub, - AuthProvider: externalAuthProvider, - }, &rs); err != nil && err != sql.ErrNoRows { - return nil, err - } - if l := rs.LocalpartExternalID; l != nil { - localpart = l.Localpart - } + localpartExternalID, err := m.userAPI.QueryExternalUserIDByLocalpartAndProvider(ctx, sub, externalAuthProvider) + if err != nil && err != sql.ErrNoRows { + return nil, err + } + if localpartExternalID != nil { + localpart = localpartExternalID.Localpart } if localpart == "" { @@ -253,11 +248,7 @@ func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token st } } - if err = m.userAPI.PerformLocalpartExternalUserIDCreation(ctx, &api.PerformLocalpartExternalUserIDCreationRequest{ - Localpart: userID.Local(), - ExternalID: sub, - AuthProvider: externalAuthProvider, - }); err != nil { + if err = m.userAPI.PerformLocalpartExternalUserIDCreation(ctx, userID.Local(), sub, externalAuthProvider); err != nil { logger.WithError(err).Error("PerformLocalpartExternalUserIDCreation") return nil, err } diff --git a/userapi/api/api.go b/userapi/api/api.go index 9b1319986..31059f5ad 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -31,8 +31,8 @@ type UserInternalAPI interface { FederationUserAPI QuerySearchProfilesAPI // used by p2p demos - QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, req *QueryLocalpartExternalIDRequest, res *QueryLocalpartExternalIDResponse) (err error) - PerformLocalpartExternalUserIDCreation(ctx context.Context, req *PerformLocalpartExternalUserIDCreationRequest) (err error) + QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, externalID, authProvider string) (*LocalpartExternalID, error) + PerformLocalpartExternalUserIDCreation(ctx context.Context, localpart, externalID, authProvider string) (error) } // api functions required by the appservice api @@ -667,22 +667,6 @@ type QueryAccountByLocalpartRequest struct { type QueryAccountByLocalpartResponse struct { Account *Account } - -type QueryLocalpartExternalIDRequest struct { - ExternalID string - AuthProvider string -} - -type QueryLocalpartExternalIDResponse struct { - LocalpartExternalID *LocalpartExternalID -} - -type PerformLocalpartExternalUserIDCreationRequest struct { - Localpart string - ExternalID string - AuthProvider string -} - // API functions required by the clientapi type ClientKeyAPI interface { UploadDeviceKeysAPI diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 2b500c95d..e4c846e56 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -604,13 +604,12 @@ func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api. return } -func (a *UserInternalAPI) PerformLocalpartExternalUserIDCreation(ctx context.Context, req *api.PerformLocalpartExternalUserIDCreationRequest) (err error) { - return a.DB.CreateLocalpartExternalID(ctx, req.Localpart, req.ExternalID, req.AuthProvider) +func (a *UserInternalAPI) PerformLocalpartExternalUserIDCreation(ctx context.Context, localpart, externalID, authProvider string) (err error) { + return a.DB.CreateLocalpartExternalID(ctx, localpart, externalID, authProvider) } -func (a *UserInternalAPI) QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, req *api.QueryLocalpartExternalIDRequest, res *api.QueryLocalpartExternalIDResponse) (err error) { - res.LocalpartExternalID, err = a.DB.GetLocalpartForExternalID(ctx, req.ExternalID, req.AuthProvider) - return +func (a *UserInternalAPI) QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) { + return a.DB.GetLocalpartForExternalID(ctx, externalID, authProvider) } // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem From ff2ba0313a04aaba1af1f19a82bd2bb0edfd1a0c Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 12 Feb 2025 18:01:33 +0000 Subject: [PATCH 67/71] msc3861: ++ --- clientapi/admin_test.go | 2 +- setup/monolith.go | 2 +- userapi/api/api.go | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 160fe8f36..51a9dfba2 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -1562,7 +1562,7 @@ func TestAdminCheckUsernameAvailable(t *testing.T) { } // Nothing more to check, test is done. - if tc.wantOK { + if !tc.wantOK { return } diff --git a/setup/monolith.go b/setup/monolith.go index b61633c1a..915446fe8 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -93,7 +93,7 @@ type UserVerifierProvider struct { } func (u *UserVerifierProvider) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) { - return u.VerifyUserFromRequest(req) + return u.UserVerifier.VerifyUserFromRequest(req) } func NewUserVerifierProvider(userVerifier httputil.UserVerifier) *UserVerifierProvider { diff --git a/userapi/api/api.go b/userapi/api/api.go index 31059f5ad..ec3ae5f31 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -32,7 +32,7 @@ type UserInternalAPI interface { QuerySearchProfilesAPI // used by p2p demos QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, externalID, authProvider string) (*LocalpartExternalID, error) - PerformLocalpartExternalUserIDCreation(ctx context.Context, localpart, externalID, authProvider string) (error) + PerformLocalpartExternalUserIDCreation(ctx context.Context, localpart, externalID, authProvider string) error } // api functions required by the appservice api @@ -667,6 +667,7 @@ type QueryAccountByLocalpartRequest struct { type QueryAccountByLocalpartResponse struct { Account *Account } + // API functions required by the clientapi type ClientKeyAPI interface { UploadDeviceKeysAPI From c490badadc48351f9b3f6816a1d6889e68fa9f9b Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 12 Feb 2025 20:40:00 +0000 Subject: [PATCH 68/71] msc3861: delete QueryMasterKeys function and related as it's redundant and no longer needed --- clientapi/routing/key_crosssigning.go | 11 +---- setup/mscs/msc3861/msc3861.go | 15 ++++++- userapi/api/api.go | 11 ----- userapi/internal/key_api.go | 13 ------ userapi/storage/interface.go | 1 - .../postgres/cross_signing_keys_table.go | 41 ++----------------- userapi/storage/shared/storage.go | 5 --- .../sqlite3/cross_signing_keys_table.go | 41 ++----------------- userapi/storage/tables/interface.go | 1 - 9 files changed, 21 insertions(+), 118 deletions(-) diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index a0f7f06e1..b0edb6062 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -31,7 +31,6 @@ type crossSigningRequest struct { type UploadKeysAPI interface { QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) - QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) api.UploadDeviceKeysAPI } @@ -76,15 +75,7 @@ func UploadCrossSigningDeviceKeys( // With MSC3861, UIA is not possible. Instead, the auth service has to explicitly mark the master key as replaceable. if cfg.MSCs.MSC3861Enabled() { - masterKeyResp := api.QueryMasterKeysResponse{} - keyserverAPI.QueryMasterKeys(req.Context(), &api.QueryMasterKeysRequest{UserID: device.UserID}, &masterKeyResp) - - if masterKeyResp.Error != nil { - logger.WithError(masterKeyResp.Error).Error("Failed to query master key") - return convertKeyError(masterKeyResp.Error) - } - - requireUIA := !sessions.isCrossSigningKeysReplacementAllowed(device.UserID) && masterKeyResp.Key != nil + requireUIA := !sessions.isCrossSigningKeysReplacementAllowed(device.UserID) if requireUIA { url := "" if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" { diff --git a/setup/mscs/msc3861/msc3861.go b/setup/mscs/msc3861/msc3861.go index 4dd5d517c..cda2266c4 100644 --- a/setup/mscs/msc3861/msc3861.go +++ b/setup/mscs/msc3861/msc3861.go @@ -6,6 +6,8 @@ package msc3861 import ( + "errors" + "github.com/element-hq/dendrite/setup" "github.com/matrix-org/gomatrixserverlib/fclient" ) @@ -20,6 +22,17 @@ func Enable(m *setup.Monolith) error { if err != nil { return err } - m.UserVerifierProvider = setup.NewUserVerifierProvider(userVerifier) + + if m.UserVerifierProvider == nil { + return errors.New("msc3861: UserVerifierProvider is not initialised") + } + + provider, ok := m.UserVerifierProvider.(*setup.UserVerifierProvider) + if !ok { + return errors.New("msc3861: the expected type of m.UserVerifierProvider is *setup.UserVerifierProvider") + } + + provider.UserVerifier = userVerifier + return nil } diff --git a/userapi/api/api.go b/userapi/api/api.go index ec3ae5f31..3c46b769b 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -672,7 +672,6 @@ type QueryAccountByLocalpartResponse struct { type ClientKeyAPI interface { UploadDeviceKeysAPI QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) - QueryMasterKeys(ctx context.Context, req *QueryMasterKeysRequest, res *QueryMasterKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) @@ -934,16 +933,6 @@ type QueryKeysResponse struct { Error *KeyError } -type QueryMasterKeysRequest struct { - UserID string -} - -type QueryMasterKeysResponse struct { - Key spec.Base64Bytes - // Set if there was a fatal error processing this query - Error *KeyError -} - type QueryKeyChangesRequest struct { // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning Offset int64 diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index eb7597ab9..24148eea0 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -234,19 +234,6 @@ func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *a return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) } -func (a *UserInternalAPI) QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) { - crossSigningKeyMap, err := a.KeyDatabase.CrossSigningKeysDataForUserAndKeyType(ctx, req.UserID, fclient.CrossSigningKeyPurposeMaster) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query user cross signing master keys: %s", err), - } - return - } - if key, ok := crossSigningKeyMap[fclient.CrossSigningKeyPurposeMaster]; ok { - res.Key = key - } -} - // nolint:gocyclo func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { var respMu sync.Mutex diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 3cf7e7659..41db12220 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -227,7 +227,6 @@ type KeyDatabase interface { CrossSigningKeysForUser(ctx context.Context, userID string) (map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey, error) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) - CrossSigningKeysDataForUserAndKeyType(ctx context.Context, userID string, keyType fclient.CrossSigningKeyPurpose) (types.CrossSigningKeyMap, error) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index f05f7845a..b0f92cf81 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -32,20 +32,15 @@ const selectCrossSigningKeysForUserSQL = "" + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + " WHERE user_id = $1" -const selectCrossSigningKeysForUserAndKeyTypeSQL = "" + - "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + - " WHERE user_id = $1 AND key_type = $2" - const upsertCrossSigningKeysForUserSQL = "" + "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + " VALUES($1, $2, $3)" + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt } func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -63,7 +58,6 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro } return s, sqlutil.StatementList{ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, - {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, }.Prepare(db) } @@ -93,35 +87,6 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( return } -func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUserAndKeyType( - ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, -) (r types.CrossSigningKeyMap, err error) { - keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] - if !ok { - return nil, fmt.Errorf("unknown key purpose %q", keyType) - } - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserAndKeyTypeStmt).QueryContext(ctx, userID, keyTypeInt) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "SelectCrossSigningKeysForUserAndKeyType: rows.close() failed") - r = types.CrossSigningKeyMap{} - for rows.Next() { - var keyTypeInt int16 - var keyData spec.Base64Bytes - if err = rows.Scan(&keyTypeInt, &keyData); err != nil { - return nil, err - } - keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] - if !ok { - return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) - } - r[keyType] = keyData - } - err = rows.Err() - return -} - func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, ) error { diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 17140e69c..e7f7789e1 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -1172,11 +1172,6 @@ func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID st return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) } -// CrossSigningKeysForUserAndKeyType returns the latest known cross-signing keys for a user and key type, if any. -func (d *KeyDatabase) CrossSigningKeysDataForUserAndKeyType(ctx context.Context, userID string, keyType fclient.CrossSigningKeyPurpose) (types.CrossSigningKeyMap, error) { - return d.CrossSigningKeysTable.SelectCrossSigningKeysForUserAndKeyType(ctx, nil, userID, keyType) -} - // CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index c57ffd398..e34e0d36b 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -32,19 +32,14 @@ const selectCrossSigningKeysForUserSQL = "" + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + " WHERE user_id = $1" -const selectCrossSigningKeysForUserAndKeyTypeSQL = "" + - "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + - " WHERE user_id = $1 AND key_type = $2" - const upsertCrossSigningKeysForUserSQL = "" + "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + " VALUES($1, $2, $3)" type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - selectCrossSigningKeysForUserAndKeyTypeStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt } func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -62,7 +57,6 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) } return s, sqlutil.StatementList{ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, - {&s.selectCrossSigningKeysForUserAndKeyTypeStmt, selectCrossSigningKeysForUserAndKeyTypeSQL}, {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, }.Prepare(db) } @@ -92,35 +86,6 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( return } -func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUserAndKeyType( - ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, -) (r types.CrossSigningKeyMap, err error) { - keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] - if !ok { - return nil, fmt.Errorf("unknown key purpose %q", keyType) - } - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserAndKeyTypeStmt).QueryContext(ctx, userID, keyTypeInt) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "SelectCrossSigningKeysForUserAndKeyType: rows.close() failed") - r = types.CrossSigningKeyMap{} - for rows.Next() { - var keyTypeInt int16 - var keyData spec.Base64Bytes - if err = rows.Scan(&keyTypeInt, &keyData); err != nil { - return nil, err - } - keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] - if !ok { - return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) - } - r[keyType] = keyData - } - err = rows.Err() - return -} - func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error { keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index fd9f10c30..e0dedcb32 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -198,7 +198,6 @@ type StaleDeviceLists interface { type CrossSigningKeys interface { SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) - SelectCrossSigningKeysForUserAndKeyType(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose) (r types.CrossSigningKeyMap, err error) UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error } From 1b8a659ecab972e7c8d8807cb1e6e232827e7a96 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 12 Feb 2025 20:54:46 +0000 Subject: [PATCH 69/71] msc3861: tests --- clientapi/routing/key_crosssigning_test.go | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/clientapi/routing/key_crosssigning_test.go b/clientapi/routing/key_crosssigning_test.go index 0db15ab92..fac6a9607 100644 --- a/clientapi/routing/key_crosssigning_test.go +++ b/clientapi/routing/key_crosssigning_test.go @@ -24,7 +24,6 @@ import ( type mockKeyAPI struct { t *testing.T queryKeysData map[string]api.QueryKeysResponse - queryMasterKeysData map[string]api.QueryMasterKeysResponse } func (m mockKeyAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { @@ -36,14 +35,6 @@ func (m mockKeyAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, re } } -func (m mockKeyAPI) QueryMasterKeys(ctx context.Context, req *api.QueryMasterKeysRequest, res *api.QueryMasterKeysResponse) { - res.Key = m.queryMasterKeysData[req.UserID].Key - res.Error = m.queryMasterKeysData[req.UserID].Error - if m.t != nil { - m.t.Logf("QueryMasterKeys: %+v => %+v", req, res) - } -} - func (m mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { // Just a dummy upload which always succeeds } @@ -67,9 +58,6 @@ func Test_UploadCrossSigningDeviceKeys_ValidRequest(t *testing.T) { queryKeysData: map[string]api.QueryKeysResponse{ "@user:example.com": {}, }, - queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ - "@user:example.com": {}, - }, } device := &api.Device{UserID: "@user:example.com", ID: "device"} cfg := &config.ClientAPI{ @@ -130,11 +118,6 @@ func Test_UploadCrossSigningDeviceKeys_Unauthorised(t *testing.T) { UserSigningKeys: nil, }, }, - queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ - "@user:example.com": { - Key: spec.Base64Bytes("key1"), - }, - }, } device := &api.Device{UserID: "@user:example.com", ID: "device"} cfg := &config.ClientAPI{ @@ -194,11 +177,6 @@ func Test_UploadCrossSigningDeviceKeys_ExistingKeysMismatch(t *testing.T) { }, }, }, - queryMasterKeysData: map[string]api.QueryMasterKeysResponse{ - "@user:example.com": { - Key: spec.Base64Bytes("different_key"), - }, - }, } device := &api.Device{UserID: "@user:example.com", ID: "device"} From b74b52de4ad23afcc1cf1a58d8a7325367112407 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 12 Feb 2025 20:55:01 +0000 Subject: [PATCH 70/71] remove deprecated linters from golangci-lint --- .golangci.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 6f3fd3627..51650170e 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -167,7 +167,6 @@ linters: - gosimple - govet - ineffassign - - megacheck - misspell # Check code comments, whereas misspell in CI checks *.md files - nakedret - staticcheck @@ -182,13 +181,9 @@ linters: - gochecknoinits - gocritic - gofmt - - golint - gosec # Should turn back on soon - - interfacer - lll - - maligned - prealloc # Should turn back on soon - - scopelint - stylecheck - typecheck # Should turn back on soon - unconvert # Should turn back on soon From 950555a5a580cde4380601147f0a97b384a2ad79 Mon Sep 17 00:00:00 2001 From: Roman Isaev Date: Wed, 12 Feb 2025 20:58:09 +0000 Subject: [PATCH 71/71] goimports --- clientapi/routing/key_crosssigning_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clientapi/routing/key_crosssigning_test.go b/clientapi/routing/key_crosssigning_test.go index fac6a9607..9ee7a9697 100644 --- a/clientapi/routing/key_crosssigning_test.go +++ b/clientapi/routing/key_crosssigning_test.go @@ -22,8 +22,8 @@ import ( // TODO: add more tests to cover cases related to MSC3861 type mockKeyAPI struct { - t *testing.T - queryKeysData map[string]api.QueryKeysResponse + t *testing.T + queryKeysData map[string]api.QueryKeysResponse } func (m mockKeyAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {