diff --git a/docs/api/swagger.yaml b/docs/api/swagger.yaml index 46ed95c827..4dfddd8f3e 100644 --- a/docs/api/swagger.yaml +++ b/docs/api/swagger.yaml @@ -7427,10 +7427,35 @@ paths: in: query name: limit type: integer - - in: query + - description: Types of notifications to include. If not provided, all notification types will be included. + in: query items: + enum: + - follow + - follow_request + - mention + - reblog + - favourite + - poll + - status + - admin.sign_up + type: string + name: types[] + type: array + - description: Types of notifications to exclude. + in: query + items: + enum: + - follow + - follow_request + - mention + - reblog + - favourite + - poll + - status + - admin.sign_up type: string - name: exclude_types + name: exclude_types[] type: array produces: - application/json diff --git a/internal/api/client/notifications/notifications.go b/internal/api/client/notifications/notifications.go index 8e08904617..ab015427e7 100644 --- a/internal/api/client/notifications/notifications.go +++ b/internal/api/client/notifications/notifications.go @@ -34,7 +34,9 @@ const ( BasePathWithID = BasePath + "/:" + IDKey BasePathWithClear = BasePath + "/clear" - // ExcludeTypes is an array specifying notification types to exclude + // IncludeTypesKey names an array param specifying notification types to include. + IncludeTypesKey = "include_types[]" + // ExcludeTypesKey names an array param specifying notification types to exclude. ExcludeTypesKey = "exclude_types[]" MaxIDKey = "max_id" LimitKey = "limit" diff --git a/internal/api/client/notifications/notifications_test.go b/internal/api/client/notifications/notifications_test.go new file mode 100644 index 0000000000..23af65cb42 --- /dev/null +++ b/internal/api/client/notifications/notifications_test.go @@ -0,0 +1,109 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package notifications_test + +import ( + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/api/client/notifications" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/email" + "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/filter/visibility" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type NotificationsTestSuite struct { + // standard suite interfaces + suite.Suite + db db.DB + tc *typeutils.Converter + mediaManager *media.Manager + federator *federation.Federator + emailSender email.Sender + processor *processing.Processor + storage *storage.Driver + state state.State + + // standard suite models + testTokens map[string]*gtsmodel.Token + testClients map[string]*gtsmodel.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + testAttachments map[string]*gtsmodel.MediaAttachment + testStatuses map[string]*gtsmodel.Status + testFollows map[string]*gtsmodel.Follow + testNotifications map[string]*gtsmodel.Notification + + // module being tested + notificationsModule *notifications.Module +} + +func (suite *NotificationsTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testFollows = testrig.NewTestFollows() + suite.testNotifications = testrig.NewTestNotifications() +} + +func (suite *NotificationsTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartNoopWorkers(&suite.state) + + testrig.InitTestConfig() + testrig.InitTestLog() + + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db + suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + + suite.tc = typeutils.NewConverter(&suite.state) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + + testrig.StandardDBSetup(suite.db, nil) + testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) + suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) + suite.notificationsModule = notifications.New(suite.processor) +} + +func (suite *NotificationsTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) + testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) +} diff --git a/internal/api/client/notifications/notificationsget.go b/internal/api/client/notifications/notificationsget.go index da43cffec1..533932d11b 100644 --- a/internal/api/client/notifications/notificationsget.go +++ b/internal/api/client/notifications/notificationsget.go @@ -80,11 +80,37 @@ import ( // in: query // required: false // - -// name: exclude_types +// name: types[] // type: array // items: // type: string -// description: Array of types of notifications to exclude (follow, favourite, reblog, mention, poll, follow_request) +// enum: +// - follow +// - follow_request +// - mention +// - reblog +// - favourite +// - poll +// - status +// - admin.sign_up +// description: Types of notifications to include. If not provided, all notification types will be included. +// in: query +// required: false +// - +// name: exclude_types[] +// type: array +// items: +// type: string +// enum: +// - follow +// - follow_request +// - mention +// - reblog +// - favourite +// - poll +// - status +// - admin.sign_up +// description: Types of notifications to exclude. // in: query // required: false // @@ -145,6 +171,7 @@ func (m *Module) NotificationsGETHandler(c *gin.Context) { c.Query(SinceIDKey), c.Query(MinIDKey), limit, + c.QueryArray(IncludeTypesKey), c.QueryArray(ExcludeTypesKey), ) if errWithCode != nil { diff --git a/internal/api/client/notifications/notificationsget_test.go b/internal/api/client/notifications/notificationsget_test.go new file mode 100644 index 0000000000..118303cae8 --- /dev/null +++ b/internal/api/client/notifications/notificationsget_test.go @@ -0,0 +1,253 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package notifications_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/api/client/notifications" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *NotificationsTestSuite) getNotifications( + account *gtsmodel.Account, + token *gtsmodel.Token, + user *gtsmodel.User, + maxID string, + minID string, + limit int, + includeTypes []string, + excludeTypes []string, + expectedHTTPStatus int, + expectedBody string, +) ([]*apimodel.Notification, string, error) { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, account) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(token)) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, user) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodGet, config.GetProtocol()+"://"+config.GetHost()+"/api/"+notifications.BasePath, nil) + ctx.Request.Header.Set("accept", "application/json") + query := url.Values{} + if maxID != "" { + query.Set(notifications.MaxIDKey, maxID) + } + if minID != "" { + query.Set(notifications.MinIDKey, maxID) + } + if limit != 0 { + query.Set(notifications.LimitKey, strconv.Itoa(limit)) + } + if len(includeTypes) > 0 { + query[notifications.IncludeTypesKey] = includeTypes + } + if len(excludeTypes) > 0 { + query[notifications.ExcludeTypesKey] = excludeTypes + } + ctx.Request.URL.RawQuery = query.Encode() + + // trigger the handler + suite.notificationsModule.NotificationsGETHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return nil, "", err + } + + errs := gtserror.NewMultiError(2) + + // check code + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return nil, "", errs.Combine() + } + + resp := make([]*apimodel.Notification, 0) + if err := json.Unmarshal(b, &resp); err != nil { + return nil, "", err + } + + return resp, result.Header.Get("Link"), nil +} + +// Test that we can retrieve at least one notification and the expected Link header. +func (suite *NotificationsTestSuite) TestGetNotificationsSingle() { + testAccount := suite.testAccounts["local_account_1"] + testToken := suite.testTokens["local_account_1"] + testUser := suite.testUsers["local_account_1"] + + maxID := "" + minID := "" + limit := 10 + includeTypes := []string(nil) + excludeTypes := []string(nil) + expectedHTTPStatus := http.StatusOK + expectedBody := "" + + notifications, linkHeader, err := suite.getNotifications( + testAccount, + testToken, + testUser, + maxID, + minID, + limit, + includeTypes, + excludeTypes, + expectedHTTPStatus, + expectedBody, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Len(notifications, 1) + suite.Equal(`; rel="next", ; rel="prev"`, linkHeader) +} + +// Add some extra notifications of different types than the fixture's single fav notification per account. +func (suite *NotificationsTestSuite) addMoreNotifications(testAccount *gtsmodel.Account) { + for _, b := range []*gtsmodel.Notification{ + { + ID: id.NewULID(), + NotificationType: gtsmodel.NotificationFollowRequest, + TargetAccountID: testAccount.ID, + OriginAccountID: suite.testAccounts["local_account_2"].ID, + }, + { + ID: id.NewULID(), + NotificationType: gtsmodel.NotificationFollow, + TargetAccountID: testAccount.ID, + OriginAccountID: suite.testAccounts["remote_account_2"].ID, + }, + } { + if err := suite.db.Put(context.Background(), b); err != nil { + suite.FailNow(err.Error()) + } + } +} + +// Test that we can exclude a notification type. +func (suite *NotificationsTestSuite) TestGetNotificationsExcludeOneType() { + testAccount := suite.testAccounts["local_account_1"] + testToken := suite.testTokens["local_account_1"] + testUser := suite.testUsers["local_account_1"] + + suite.addMoreNotifications(testAccount) + + maxID := "" + minID := "" + limit := 10 + includeTypes := []string(nil) + excludeTypes := []string{"follow_request"} + expectedHTTPStatus := http.StatusOK + expectedBody := "" + + notifications, _, err := suite.getNotifications( + testAccount, + testToken, + testUser, + maxID, + minID, + limit, + includeTypes, + excludeTypes, + expectedHTTPStatus, + expectedBody, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + // This should not include the follow request notification. + suite.Len(notifications, 2) + for _, notification := range notifications { + suite.NotEqual("follow_request", notification.Type) + } +} + +// Test that we can fetch only a single notification type. +func (suite *NotificationsTestSuite) TestGetNotificationsIncludeOneType() { + testAccount := suite.testAccounts["local_account_1"] + testToken := suite.testTokens["local_account_1"] + testUser := suite.testUsers["local_account_1"] + + suite.addMoreNotifications(testAccount) + + maxID := "" + minID := "" + limit := 10 + includeTypes := []string{"favourite"} + excludeTypes := []string(nil) + expectedHTTPStatus := http.StatusOK + expectedBody := "" + + notifications, _, err := suite.getNotifications( + testAccount, + testToken, + testUser, + maxID, + minID, + limit, + includeTypes, + excludeTypes, + expectedHTTPStatus, + expectedBody, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + // This should only include the fav notification. + suite.Len(notifications, 1) + for _, notification := range notifications { + suite.Equal("favourite", notification.Type) + } +} + +func TestBookmarkTestSuite(t *testing.T) { + suite.Run(t, new(NotificationsTestSuite)) +} diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 04688a379c..af147ab080 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -200,6 +200,7 @@ func (n *notificationDB) GetAccountNotifications( sinceID string, minID string, limit int, + includeTypes []string, excludeTypes []string, ) ([]*gtsmodel.Notification, error) { // Ensure reasonable @@ -237,9 +238,14 @@ func (n *notificationDB) GetAccountNotifications( frontToBack = false // page up } - for _, excludeType := range excludeTypes { + if len(includeTypes) > 0 { + // Include only requested notification types. + q = q.Where("? IN (?)", bun.Ident("notification.notification_type"), bun.In(includeTypes)) + } + + if len(excludeTypes) > 0 { // Filter out unwanted notif types. - q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType) + q = q.Where("? NOT IN (?)", bun.Ident("notification.notification_type"), bun.In(excludeTypes)) } // Return only notifs for this account. diff --git a/internal/db/bundb/notification_test.go b/internal/db/bundb/notification_test.go index 984c0ef8dd..eb2c020666 100644 --- a/internal/db/bundb/notification_test.go +++ b/internal/db/bundb/notification_test.go @@ -97,6 +97,7 @@ func (suite *NotificationTestSuite) TestGetAccountNotificationsWithSpam() { "", 20, nil, + nil, ) suite.NoError(err) timeTaken := time.Since(before) @@ -119,6 +120,7 @@ func (suite *NotificationTestSuite) TestGetAccountNotificationsWithoutSpam() { "", 20, nil, + nil, ) suite.NoError(err) timeTaken := time.Since(before) @@ -143,6 +145,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() { "", 20, nil, + nil, ) if err != nil { suite.FailNow(err.Error()) @@ -163,6 +166,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() { "", 20, nil, + nil, ) if err != nil { suite.FailNow(err.Error()) @@ -184,6 +188,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() { "", 20, nil, + nil, ) suite.NoError(err) suite.Nil(notifications) diff --git a/internal/db/notification.go b/internal/db/notification.go index 9ff459b9cb..2e8f5ed1f7 100644 --- a/internal/db/notification.go +++ b/internal/db/notification.go @@ -25,12 +25,13 @@ import ( // Notification contains functions for creating and getting notifications. type Notification interface { - // GetNotifications returns a slice of notifications that pertain to the given accountID. + // GetAccountNotifications returns a slice of notifications that pertain to the given accountID. // // Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest). - GetAccountNotifications(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, excludeTypes []string) ([]*gtsmodel.Notification, error) + // If includeTypes is empty, *all* notification types will be included. + GetAccountNotifications(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, includeTypes []string, excludeTypes []string) ([]*gtsmodel.Notification, error) - // GetNotification returns one notification according to its id. + // GetNotificationByID returns one notification according to its id. GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) // GetNotificationsByIDs returns a slice of notifications of the the provided IDs. diff --git a/internal/processing/timeline/notification.go b/internal/processing/timeline/notification.go index 5156a1cdfc..6976494440 100644 --- a/internal/processing/timeline/notification.go +++ b/internal/processing/timeline/notification.go @@ -34,8 +34,26 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, excludeTypes []string) (*apimodel.PageableResponse, gtserror.WithCode) { - notifs, err := p.state.DB.GetAccountNotifications(ctx, authed.Account.ID, maxID, sinceID, minID, limit, excludeTypes) +func (p *Processor) NotificationsGet( + ctx context.Context, + authed *oauth.Auth, + maxID string, + sinceID string, + minID string, + limit int, + includeTypes []string, + excludeTypes []string, +) (*apimodel.PageableResponse, gtserror.WithCode) { + notifs, err := p.state.DB.GetAccountNotifications( + ctx, + authed.Account.ID, + maxID, + sinceID, + minID, + limit, + includeTypes, + excludeTypes, + ) if err != nil && !errors.Is(err, db.ErrNoEntries) { err = fmt.Errorf("NotificationsGet: db error getting notifications: %w", err) return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/workers/surfacenotify_test.go b/internal/processing/workers/surfacenotify_test.go index 7b448781d5..18d0277ae2 100644 --- a/internal/processing/workers/surfacenotify_test.go +++ b/internal/processing/workers/surfacenotify_test.go @@ -87,7 +87,7 @@ func (suite *SurfaceNotifyTestSuite) TestSpamNotifs() { notifs, err := testStructs.State.DB.GetAccountNotifications( gtscontext.SetBarebones(ctx), targetAccount.ID, - "", "", "", 0, nil, + "", "", "", 0, nil, nil, ) if err != nil { suite.FailNow(err.Error())