diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8df3d160e1..50270ae388 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,9 +20,20 @@ jobs: run: make test-integration env: REDIS_ADDRESS: redis:6379 + SQL_USERNAME: root + SQL_PASSWORD: my-secret-pw + SQL_ADDRESS: localhost:3306 + SQL_DBNAME: reva services: redis: image: registry.cern.ch/docker.io/webhippie/redis + mysql: + image: mysql + ports: + - 3306:3306 + env: + MYSQL_ROOT_PASSWORD: my-secret-pw + MYSQL_DATABASE: reva go: runs-on: self-hosted steps: diff --git a/changelog/unreleased/ocm-invite-sql-tests.md b/changelog/unreleased/ocm-invite-sql-tests.md new file mode 100644 index 0000000000..ae029a9a4d --- /dev/null +++ b/changelog/unreleased/ocm-invite-sql-tests.md @@ -0,0 +1,3 @@ +Enhancement: Tests for invitation manager SQL driver + +https://github.com/cs3org/reva/pull/3619 \ No newline at end of file diff --git a/tests/integration/grpc/fixtures/ocm-server-cernbox-grpc.toml b/tests/integration/grpc/fixtures/ocm-server-cernbox-grpc.toml index 5a0387c893..720c0925bd 100644 --- a/tests/integration/grpc/fixtures/ocm-server-cernbox-grpc.toml +++ b/tests/integration/grpc/fixtures/ocm-server-cernbox-grpc.toml @@ -20,12 +20,18 @@ driver = "static" basic = "{{grpc_address}}" [grpc.services.ocminvitemanager] -driver = "json" +driver = "{{ocm_driver}}" provider_domain = "cernbox.cern.ch" [grpc.services.ocminvitemanager.drivers.json] file = "{{invite_token_file}}" +[grpc.services.ocminvitemanager.drivers.sql] +db_username = "{{db_username}}" +db_password = "{{db_password}}" +db_address = "{{db_address}}" +db_name = "{{db_name}}" + [grpc.services.ocmproviderauthorizer] driver = "json" diff --git a/tests/integration/grpc/fixtures/ocm-server-cesnet-grpc.toml b/tests/integration/grpc/fixtures/ocm-server-cesnet-grpc.toml index 795855ba9f..77c43a6d2e 100644 --- a/tests/integration/grpc/fixtures/ocm-server-cesnet-grpc.toml +++ b/tests/integration/grpc/fixtures/ocm-server-cesnet-grpc.toml @@ -20,12 +20,18 @@ driver = "static" basic = "{{grpc_address}}" [grpc.services.ocminvitemanager] -driver = "json" +driver = "{{ocm_driver}}" provider_domain = "cesnet.cz" [grpc.services.ocminvitemanager.drivers.json] file = "{{invite_token_file}}" +[grpc.services.ocminvitemanager.drivers.sql] +db_username = "{{db_username}}" +db_password = "{{db_password}}" +db_address = "{{db_address}}" +db_name = "{{db_name}}" + [grpc.services.ocmproviderauthorizer] driver = "json" diff --git a/tests/integration/grpc/ocm_init_test.go b/tests/integration/grpc/ocm_init_test.go new file mode 100644 index 0000000000..03851eb716 --- /dev/null +++ b/tests/integration/grpc/ocm_init_test.go @@ -0,0 +1,181 @@ +// Copyright 2018-2023 CERN +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// In applying this license, CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +package grpc_test + +import ( + "database/sql" + "fmt" + "os" + "time" + + conversions "github.com/cs3org/reva/pkg/cbox/utils" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/pkg/errors" + + userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1" + invitepb "github.com/cs3org/go-cs3apis/cs3/ocm/invite/v1beta1" + "github.com/cs3org/reva/tests/helpers" + + _ "github.com/go-sql-driver/mysql" +) + +func initData(driver string, tokens []*invitepb.InviteToken, acceptedUsers map[string][]*userpb.User) (map[string]string, func(), error) { + variables := map[string]string{ + "ocm_driver": driver, + } + switch driver { + case "json": + return initJSONData(variables, tokens, acceptedUsers) + case "sql": + return initSQLData(variables, tokens, acceptedUsers) + } + + return nil, nil, errors.New("driver not found") +} + +func initJSONData(variables map[string]string, tokens []*invitepb.InviteToken, acceptedUsers map[string][]*userpb.User) (map[string]string, func(), error) { + data := map[string]any{} + + if len(tokens) != 0 { + m := map[string]*invitepb.InviteToken{} + for _, tkn := range tokens { + m[tkn.Token] = tkn + } + data["invites"] = m + } + + if len(acceptedUsers) != 0 { + data["accepted_users"] = acceptedUsers + } + + inviteTokenFile, err := helpers.TempJSONFile(data) + if err != nil { + return nil, nil, err + } + cleanup := func() { + Expect(os.RemoveAll(inviteTokenFile)).To(Succeed()) + } + variables["invite_token_file"] = inviteTokenFile + return variables, cleanup, nil +} + +func initTables(db *sql.DB) error { + table1 := ` +CREATE TABLE IF NOT EXISTS ocm_tokens ( + token VARCHAR(255) NOT NULL PRIMARY KEY, + initiator VARCHAR(255) NOT NULL, + expiration DATETIME NOT NULL, + description VARCHAR(255) DEFAULT NULL +)` + table2 := ` +CREATE TABLE IF NOT EXISTS ocm_remote_users ( + initiator VARCHAR(255) NOT NULL, + opaque_user_id VARCHAR(255) NOT NULL, + idp VARCHAR(255) NOT NULL, + email VARCHAR(255) NOT NULL, + display_name VARCHAR(255) NOT NULL, + PRIMARY KEY (initiator, opaque_user_id, idp) +)` + if _, err := db.Exec(table1); err != nil { + return err + } + if _, err := db.Exec(table2); err != nil { + return err + } + return nil +} + +func dropTables(db *sql.DB) error { + drop1 := "DROP TABLE IF EXISTS ocm_tokens" + drop2 := "DROP TABLE IF EXISTS ocm_remote_users" + if _, err := db.Exec(drop1); err != nil { + return err + } + if _, err := db.Exec(drop2); err != nil { + return err + } + return nil +} + +func initSQLData(variables map[string]string, tokens []*invitepb.InviteToken, acceptedUsers map[string][]*userpb.User) (map[string]string, func(), error) { + username := os.Getenv("SQL_USERNAME") + if username == "" { + Fail("SQL_USERNAME not set") + } + password := os.Getenv("SQL_PASSWORD") + if password == "" { + Fail("SQL_PASSWORD not set") + } + address := os.Getenv("SQL_ADDRESS") + if address == "" { + Fail("SQL_ADDRESS not set") + } + database := os.Getenv("SQL_DBNAME") + if database == "" { + Fail("SQL_DBNAME not set") + } + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s", username, password, address, database)) + if err != nil { + return nil, nil, err + } + if err := initTables(db); err != nil { + return nil, nil, err + } + cleanup := func() { + Expect(dropTables(db)).To(Succeed()) + } + + variables["db_username"] = username + variables["db_password"] = password + variables["db_address"] = address + variables["db_name"] = database + + if err := initTokens(db, tokens); err != nil { + return nil, nil, err + } + if err := initAcceptedUsers(db, acceptedUsers); err != nil { + return nil, nil, err + } + + return variables, cleanup, nil +} + +func initTokens(db *sql.DB, tokens []*invitepb.InviteToken) error { + query := "INSERT INTO ocm_tokens (token, initiator, expiration, description) VALUES (?,?,?,?)" + for _, token := range tokens { + if _, err := db.Exec(query, token.Token, conversions.FormatUserID(token.UserId), time.Unix(int64(token.Expiration.Seconds), 0), token.Description); err != nil { + return err + } + } + return nil +} + +func initAcceptedUsers(db *sql.DB, acceptedUsers map[string][]*userpb.User) error { + query := "INSERT INTO ocm_remote_users (initiator, opaque_user_id, idp, email, display_name) VALUES (?,?,?,?,?)" + for initiator, users := range acceptedUsers { + for _, user := range users { + if _, err := db.Exec(query, initiator, user.Id.OpaqueId, user.Id.Idp, user.Mail, user.DisplayName); err != nil { + return err + } + } + } + return nil +} diff --git a/tests/integration/grpc/ocm_invitation_test.go b/tests/integration/grpc/ocm_invitation_test.go index 0cfcfb29cb..fbaec721f9 100644 --- a/tests/integration/grpc/ocm_invitation_test.go +++ b/tests/integration/grpc/ocm_invitation_test.go @@ -38,7 +38,6 @@ import ( "github.com/cs3org/reva/pkg/token" jwt "github.com/cs3org/reva/pkg/token/manager/jwt" "github.com/cs3org/reva/pkg/utils" - "github.com/cs3org/reva/tests/helpers" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "google.golang.org/grpc/metadata" @@ -124,341 +123,362 @@ var _ = Describe("ocm invitation workflow", func() { } ) - JustBeforeEach(func() { - tokenManager, err := jwt.New(map[string]interface{}{"secret": "changemeplease"}) - Expect(err).ToNot(HaveOccurred()) - ctxEinstein = ctxWithAuthToken(tokenManager, einstein) - ctxMarie = ctxWithAuthToken(tokenManager, marie) - revads, err = startRevads(map[string]string{ - "cernboxgw": "ocm-server-cernbox-grpc.toml", - "cernboxhttp": "ocm-server-cernbox-http.toml", - "cesnetgw": "ocm-server-cesnet-grpc.toml", - "cesnethttp": "ocm-server-cesnet-http.toml", - }, map[string]string{ - "providers": "ocm-providers.demo.json", - }, nil, variables) - Expect(err).ToNot(HaveOccurred()) - cernboxgw, err = pool.GetGatewayServiceClient(pool.Endpoint(revads["cernboxgw"].GrpcAddress)) - Expect(err).ToNot(HaveOccurred()) - cesnetgw, err = pool.GetGatewayServiceClient(pool.Endpoint(revads["cesnetgw"].GrpcAddress)) - Expect(err).ToNot(HaveOccurred()) - cernbox.Services[0].Endpoint.Path = "http://" + revads["cernboxhttp"].GrpcAddress + "/ocm" - }) - - AfterEach(func() { - for _, r := range revads { - Expect(r.Cleanup(CurrentGinkgoTestDescription().Failed)).To(Succeed()) - } - Expect(os.RemoveAll(inviteTokenFile)).To(Succeed()) - }) + for _, driver := range []string{"json", "sql"} { - Describe("einstein and marie do not know each other", func() { - BeforeEach(func() { - inviteTokenFile, err = helpers.TempJSONFile(map[string]string{}) + JustBeforeEach(func() { + tokenManager, err := jwt.New(map[string]interface{}{"secret": "changemeplease"}) + Expect(err).ToNot(HaveOccurred()) + ctxEinstein = ctxWithAuthToken(tokenManager, einstein) + ctxMarie = ctxWithAuthToken(tokenManager, marie) + variables["ocm_driver"] = driver + revads, err = startRevads(map[string]string{ + "cernboxgw": "ocm-server-cernbox-grpc.toml", + "cernboxhttp": "ocm-server-cernbox-http.toml", + "cesnetgw": "ocm-server-cesnet-grpc.toml", + "cesnethttp": "ocm-server-cesnet-http.toml", + }, map[string]string{ + "providers": "ocm-providers.demo.json", + }, nil, variables) + Expect(err).ToNot(HaveOccurred()) + cernboxgw, err = pool.GetGatewayServiceClient(pool.Endpoint(revads["cernboxgw"].GrpcAddress)) + Expect(err).ToNot(HaveOccurred()) + cesnetgw, err = pool.GetGatewayServiceClient(pool.Endpoint(revads["cesnetgw"].GrpcAddress)) Expect(err).ToNot(HaveOccurred()) - variables = map[string]string{ - "invite_token_file": inviteTokenFile, + cernbox.Services[0].Endpoint.Path = "http://" + revads["cernboxhttp"].GrpcAddress + "/ocm" + }) + + AfterEach(func() { + for _, r := range revads { + Expect(r.Cleanup(CurrentGinkgoTestDescription().Failed)).To(Succeed()) } + Expect(os.RemoveAll(inviteTokenFile)).To(Succeed()) }) - Context("einstein generates a token", func() { - It("will complete the workflow ", func() { - invitationTknRes, err := cernboxgw.GenerateInviteToken(ctxEinstein, &invitepb.GenerateInviteTokenRequest{}) + Describe("einstein and marie do not know each other", func() { + var cleanup func() + BeforeEach(func() { + variables, cleanup, err = initData(driver, nil, nil) Expect(err).ToNot(HaveOccurred()) - Expect(invitationTknRes.Status.Code).To(Equal(rpc.Code_CODE_OK)) - Expect(invitationTknRes.InviteToken).ToNot(BeNil()) - forwardRes, err := cesnetgw.ForwardInvite(ctxMarie, &invitepb.ForwardInviteRequest{ - OriginSystemProvider: cernbox, - InviteToken: invitationTknRes.InviteToken, + }) + + AfterEach(func() { + cleanup() + }) + + Context("einstein generates a token", func() { + It("will complete the workflow ", func() { + invitationTknRes, err := cernboxgw.GenerateInviteToken(ctxEinstein, &invitepb.GenerateInviteTokenRequest{}) + Expect(err).ToNot(HaveOccurred()) + Expect(invitationTknRes.Status.Code).To(Equal(rpc.Code_CODE_OK)) + Expect(invitationTknRes.InviteToken).ToNot(BeNil()) + forwardRes, err := cesnetgw.ForwardInvite(ctxMarie, &invitepb.ForwardInviteRequest{ + OriginSystemProvider: cernbox, + InviteToken: invitationTknRes.InviteToken, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(forwardRes.Status.Code).To(Equal(rpc.Code_CODE_OK)) + + Expect(forwardRes.DisplayName).To(Equal(einstein.DisplayName)) + Expect(forwardRes.Email).To(Equal(einstein.Mail)) + Expect(utils.UserEqual(forwardRes.UserId, einstein.Id)).To(BeTrue()) + + usersRes1, err := cernboxgw.FindAcceptedUsers(ctxEinstein, &invitepb.FindAcceptedUsersRequest{}) + Expect(err).ToNot(HaveOccurred()) + Expect(usersRes1.Status.Code).To(Equal(rpc.Code_CODE_OK)) + Expect(usersRes1.AcceptedUsers).To(HaveLen(1)) + info1 := usersRes1.AcceptedUsers[0] + Expect(ocmUserEqual(info1, marie)).To(BeTrue()) + + usersRes2, err := cesnetgw.FindAcceptedUsers(ctxMarie, &invitepb.FindAcceptedUsersRequest{}) + Expect(err).ToNot(HaveOccurred()) + Expect(usersRes2.Status.Code).To(Equal(rpc.Code_CODE_OK)) + Expect(usersRes2.AcceptedUsers).To(HaveLen(1)) + info2 := usersRes2.AcceptedUsers[0] + Expect(ocmUserEqual(info2, einstein)).To(BeTrue()) }) - Expect(err).ToNot(HaveOccurred()) - Expect(forwardRes.Status.Code).To(Equal(rpc.Code_CODE_OK)) - Expect(forwardRes.DisplayName).To(Equal(einstein.DisplayName)) - Expect(forwardRes.Email).To(Equal(einstein.Mail)) - Expect(forwardRes.UserId).To(Equal(einstein.Id)) + }) + }) - usersRes1, err := cernboxgw.FindAcceptedUsers(ctxEinstein, &invitepb.FindAcceptedUsersRequest{}) + Describe("an invitation workflow has been already completed between einstein and marie", func() { + var cleanup func() + BeforeEach(func() { + variables, cleanup, err = initData(driver, nil, map[string][]*userpb.User{ + einstein.Id.OpaqueId: {marie}, + marie.Id.OpaqueId: {einstein}, + }) Expect(err).ToNot(HaveOccurred()) - Expect(usersRes1.Status.Code).To(Equal(rpc.Code_CODE_OK)) - Expect(usersRes1.AcceptedUsers).To(HaveLen(1)) - info1 := usersRes1.AcceptedUsers[0] - Expect(ocmUserEqual(info1, marie)).To(BeTrue()) + }) - usersRes2, err := cesnetgw.FindAcceptedUsers(ctxMarie, &invitepb.FindAcceptedUsersRequest{}) - Expect(err).ToNot(HaveOccurred()) - Expect(usersRes2.Status.Code).To(Equal(rpc.Code_CODE_OK)) - Expect(usersRes2.AcceptedUsers).To(HaveLen(1)) - info2 := usersRes2.AcceptedUsers[0] - Expect(ocmUserEqual(info2, einstein)).To(BeTrue()) + AfterEach(func() { + cleanup() }) + Context("marie accepts a new invite token generated by einstein", func() { + It("fails with already exists code", func() { + inviteTknRes, err := cernboxgw.GenerateInviteToken(ctxEinstein, &invitepb.GenerateInviteTokenRequest{}) + Expect(err).ToNot(HaveOccurred()) + Expect(inviteTknRes.Status.Code).To(Equal(rpc.Code_CODE_OK)) + + forwardRes, err := cesnetgw.ForwardInvite(ctxMarie, &invitepb.ForwardInviteRequest{ + InviteToken: inviteTknRes.InviteToken, + OriginSystemProvider: cernbox, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(forwardRes.Status.Code).To(Equal(rpc.Code_CODE_ALREADY_EXISTS)) + }) + }) }) - }) - Describe("an invitation workflow has been already completed between einstein and marie", func() { - BeforeEach(func() { - inviteTokenFile, err = helpers.TempJSONFile(map[string]map[string][]*userpb.User{ - "accepted_users": { - einstein.Id.OpaqueId: {marie}, - marie.Id.OpaqueId: {einstein}, + Describe("marie accepts an expired token", func() { + expiredToken := &invitepb.InviteToken{ + Token: "token", + UserId: einstein.Id, + Expiration: &typesv1beta1.Timestamp{ + Seconds: 0, }, - }) - Expect(err).ToNot(HaveOccurred()) - variables = map[string]string{ - "invite_token_file": inviteTokenFile, + Description: "expired token", } - }) - Context("marie accepts a new invite token generated by einstein", func() { - It("fails with already exists code", func() { - inviteTknRes, err := cernboxgw.GenerateInviteToken(ctxEinstein, &invitepb.GenerateInviteTokenRequest{}) + var cleanup func() + BeforeEach(func() { + variables, cleanup, err = initData(driver, []*invitepb.InviteToken{expiredToken}, nil) Expect(err).ToNot(HaveOccurred()) - Expect(inviteTknRes.Status.Code).To(Equal(rpc.Code_CODE_OK)) + }) + + AfterEach(func() { + cleanup() + }) + It("will not complete the invitation workflow", func() { forwardRes, err := cesnetgw.ForwardInvite(ctxMarie, &invitepb.ForwardInviteRequest{ - InviteToken: inviteTknRes.InviteToken, + InviteToken: expiredToken, OriginSystemProvider: cernbox, }) Expect(err).ToNot(HaveOccurred()) - Expect(forwardRes.Status.Code).To(Equal(rpc.Code_CODE_ALREADY_EXISTS)) + Expect(forwardRes.Status.Code).To(Equal(rpc.Code_CODE_INVALID_ARGUMENT)) }) }) - }) - - Describe("marie accepts an expired token", func() { - expiredToken := &invitepb.InviteToken{ - Token: "token", - UserId: einstein.Id, - Expiration: &typesv1beta1.Timestamp{ - Seconds: 0, - }, - Description: "expired token", - } - BeforeEach(func() { - inviteTokenFile, err = helpers.TempJSONFile(map[string]map[string]*invitepb.InviteToken{ - "invites": { - expiredToken.Token: expiredToken, - }, - }) - Expect(err).ToNot(HaveOccurred()) - variables = map[string]string{ - "invite_token_file": inviteTokenFile, - } - }) - It("will not complete the invitation workflow", func() { - forwardRes, err := cesnetgw.ForwardInvite(ctxMarie, &invitepb.ForwardInviteRequest{ - InviteToken: expiredToken, - OriginSystemProvider: cernbox, + Describe("marie accept a not existing token", func() { + var cleanup func() + BeforeEach(func() { + variables, cleanup, err = initData(driver, nil, nil) + Expect(err).ToNot(HaveOccurred()) }) - Expect(err).ToNot(HaveOccurred()) - Expect(forwardRes.Status.Code).To(Equal(rpc.Code_CODE_INVALID_ARGUMENT)) - }) - }) - Describe("marie accept a not existing token", func() { - BeforeEach(func() { - inviteTokenFile, err = helpers.TempJSONFile(map[string]string{}) - Expect(err).ToNot(HaveOccurred()) - variables = map[string]string{ - "invite_token_file": inviteTokenFile, - } - }) + AfterEach(func() { + cleanup() + }) - It("will not complete the invitation workflow", func() { - forwardRes, err := cesnetgw.ForwardInvite(ctxMarie, &invitepb.ForwardInviteRequest{ - InviteToken: &invitepb.InviteToken{ - Token: "not-existing-token", - }, - OriginSystemProvider: cernbox, + It("will not complete the invitation workflow", func() { + forwardRes, err := cesnetgw.ForwardInvite(ctxMarie, &invitepb.ForwardInviteRequest{ + InviteToken: &invitepb.InviteToken{ + Token: "not-existing-token", + }, + OriginSystemProvider: cernbox, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(forwardRes.Status.Code).To(Equal(rpc.Code_CODE_NOT_FOUND)) }) - Expect(err).ToNot(HaveOccurred()) - Expect(forwardRes.Status.Code).To(Equal(rpc.Code_CODE_NOT_FOUND)) }) - }) - Context("clients use the http endpoints exposed by sciencemesh", func() { - var ( - cesnetURL string - cernboxURL string - tknMarie, tknEinstein string - token string - ) + Context("clients use the http endpoints exposed by sciencemesh", func() { + var ( + cesnetURL string + cernboxURL string + tknMarie, tknEinstein string + token string + ) - JustBeforeEach(func() { - cesnetURL = revads["cesnethttp"].GrpcAddress - cernboxURL = revads["cernboxhttp"].GrpcAddress + var cleanup func() + BeforeEach(func() { + variables, cleanup, err = initData(driver, nil, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + cleanup() + }) - var ok bool - tknMarie, ok = ctxpkg.ContextGetToken(ctxMarie) - Expect(ok).To(BeTrue()) - tknEinstein, ok = ctxpkg.ContextGetToken(ctxEinstein) - Expect(ok).To(BeTrue()) + JustBeforeEach(func() { + cesnetURL = revads["cesnethttp"].GrpcAddress + cernboxURL = revads["cernboxhttp"].GrpcAddress - tknRes, err := cernboxgw.GenerateInviteToken(ctxEinstein, &invitepb.GenerateInviteTokenRequest{}) - Expect(err).ToNot(HaveOccurred()) - Expect(tknRes.Status.Code).To(Equal(rpc.Code_CODE_OK)) - token = tknRes.InviteToken.Token - }) + var ok bool + tknMarie, ok = ctxpkg.ContextGetToken(ctxMarie) + Expect(ok).To(BeTrue()) + tknEinstein, ok = ctxpkg.ContextGetToken(ctxEinstein) + Expect(ok).To(BeTrue()) - acceptInvite := func(revaToken, domain, provider, token string) int { - d, err := json.Marshal(map[string]string{ - "token": token, - "providerDomain": provider, + tknRes, err := cernboxgw.GenerateInviteToken(ctxEinstein, &invitepb.GenerateInviteTokenRequest{}) + Expect(err).ToNot(HaveOccurred()) + Expect(tknRes.Status.Code).To(Equal(rpc.Code_CODE_OK)) + token = tknRes.InviteToken.Token }) - Expect(err).ToNot(HaveOccurred()) - req, err := http.NewRequestWithContext(context.TODO(), http.MethodPost, fmt.Sprintf("http://%s/sciencemesh/accept-invite", domain), bytes.NewReader(d)) - Expect(err).ToNot(HaveOccurred()) - req.Header.Set("x-access-token", revaToken) - req.Header.Set("content-type", "application/json") - res, err := http.DefaultClient.Do(req) - Expect(err).ToNot(HaveOccurred()) - defer res.Body.Close() + acceptInvite := func(revaToken, domain, provider, token string) int { + d, err := json.Marshal(map[string]string{ + "token": token, + "providerDomain": provider, + }) + Expect(err).ToNot(HaveOccurred()) + req, err := http.NewRequestWithContext(context.TODO(), http.MethodPost, fmt.Sprintf("http://%s/sciencemesh/accept-invite", domain), bytes.NewReader(d)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("x-access-token", revaToken) + req.Header.Set("content-type", "application/json") - return res.StatusCode - } + res, err := http.DefaultClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer res.Body.Close() - findAccepted := func(revaToken, domain string) ([]*userpb.User, int) { - req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, fmt.Sprintf("http://%s/sciencemesh/find-accepted-users", domain), nil) - Expect(err).ToNot(HaveOccurred()) - req.Header.Set("x-access-token", revaToken) + return res.StatusCode + } - res, err := http.DefaultClient.Do(req) - Expect(err).ToNot(HaveOccurred()) - defer res.Body.Close() + findAccepted := func(revaToken, domain string) ([]*userpb.User, int) { + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, fmt.Sprintf("http://%s/sciencemesh/find-accepted-users", domain), nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("x-access-token", revaToken) - var users []*userpb.User - _ = json.NewDecoder(res.Body).Decode(&users) - return users, res.StatusCode - } + res, err := http.DefaultClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer res.Body.Close() - generateToken := func(revaToken, domain string) (*generateInviteResponse, int) { - req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, fmt.Sprintf("http://%s/sciencemesh/generate-invite", domain), nil) - Expect(err).ToNot(HaveOccurred()) - req.Header.Set("x-access-token", revaToken) + var users []*userpb.User + _ = json.NewDecoder(res.Body).Decode(&users) + return users, res.StatusCode + } - res, err := http.DefaultClient.Do(req) - Expect(err).ToNot(HaveOccurred()) - defer res.Body.Close() + generateToken := func(revaToken, domain string) (*generateInviteResponse, int) { + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, fmt.Sprintf("http://%s/sciencemesh/generate-invite", domain), nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("x-access-token", revaToken) - var inviteRes generateInviteResponse - Expect(json.NewDecoder(res.Body).Decode(&inviteRes)).To(Succeed()) - return &inviteRes, res.StatusCode - } + res, err := http.DefaultClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer res.Body.Close() + + var inviteRes generateInviteResponse + Expect(json.NewDecoder(res.Body).Decode(&inviteRes)).To(Succeed()) + return &inviteRes, res.StatusCode + } - Context("einstein and marie do not know each other", func() { + Context("einstein and marie do not know each other", func() { - Context("marie is not logged-in", func() { - It("fails with permission denied", func() { - code := acceptInvite("", cesnetURL, "cernbox.cern.ch", token) - Expect(code).To(Equal(http.StatusUnauthorized)) + Context("marie is not logged-in", func() { + It("fails with permission denied", func() { + code := acceptInvite("", cesnetURL, "cernbox.cern.ch", token) + Expect(code).To(Equal(http.StatusUnauthorized)) + }) }) - }) - It("complete the invitation workflow", func() { - users, code := findAccepted(tknEinstein, cernboxURL) - Expect(code).To(Equal(http.StatusOK)) - Expect(ocmUsersEqual(users, []*userpb.User{})).To(BeTrue()) + It("complete the invitation workflow", func() { + users, code := findAccepted(tknEinstein, cernboxURL) + Expect(code).To(Equal(http.StatusOK)) + Expect(ocmUsersEqual(users, []*userpb.User{})).To(BeTrue()) - code = acceptInvite(tknMarie, cesnetURL, "cernbox.cern.ch", token) - Expect(code).To(Equal(http.StatusOK)) + code = acceptInvite(tknMarie, cesnetURL, "cernbox.cern.ch", token) + Expect(code).To(Equal(http.StatusOK)) - users, code = findAccepted(tknEinstein, cernboxURL) - Expect(code).To(Equal(http.StatusOK)) - Expect(ocmUsersEqual(users, []*userpb.User{marie})).To(BeTrue()) + users, code = findAccepted(tknEinstein, cernboxURL) + Expect(code).To(Equal(http.StatusOK)) + Expect(ocmUsersEqual(users, []*userpb.User{marie})).To(BeTrue()) + }) }) - }) - Context("marie already accepted an invitation before", func() { - BeforeEach(func() { - inviteTokenFile, err = helpers.TempJSONFile(map[string]map[string][]*userpb.User{ - "accepted_users": { + Context("marie already accepted an invitation before", func() { + var cleanup func() + BeforeEach(func() { + variables, cleanup, err = initData(driver, nil, map[string][]*userpb.User{ einstein.Id.OpaqueId: {marie}, marie.Id.OpaqueId: {einstein}, - }, + }) + Expect(err).ToNot(HaveOccurred()) }) - Expect(err).ToNot(HaveOccurred()) - variables = map[string]string{ - "invite_token_file": inviteTokenFile, - } - }) - It("fails the invitation workflow", func() { - users, code := findAccepted(tknEinstein, cernboxURL) - Expect(code).To(Equal(http.StatusOK)) - Expect(ocmUsersEqual(users, []*userpb.User{marie})).To(BeTrue()) + AfterEach(func() { + cleanup() + }) - code = acceptInvite(tknMarie, cesnetURL, "cernbox.cern.ch", token) - Expect(code).To(Equal(http.StatusConflict)) + It("fails the invitation workflow", func() { + users, code := findAccepted(tknEinstein, cernboxURL) + Expect(code).To(Equal(http.StatusOK)) + Expect(ocmUsersEqual(users, []*userpb.User{marie})).To(BeTrue()) - users, code = findAccepted(tknEinstein, cernboxURL) - Expect(code).To(Equal(http.StatusOK)) - Expect(ocmUsersEqual(users, []*userpb.User{marie})).To(BeTrue()) + code = acceptInvite(tknMarie, cesnetURL, "cernbox.cern.ch", token) + Expect(code).To(Equal(http.StatusConflict)) + + users, code = findAccepted(tknEinstein, cernboxURL) + Expect(code).To(Equal(http.StatusOK)) + Expect(ocmUsersEqual(users, []*userpb.User{marie})).To(BeTrue()) + }) }) - }) - Context("marie uses an expired token", func() { - expiredToken := &invitepb.InviteToken{ - Token: "token", - UserId: einstein.Id, - Expiration: &typesv1beta1.Timestamp{ - Seconds: 0, - }, - Description: "expired token", - } - BeforeEach(func() { - inviteTokenFile, err = helpers.TempJSONFile(map[string]map[string]*invitepb.InviteToken{ - "invites": { - expiredToken.Token: expiredToken, + Context("marie uses an expired token", func() { + expiredToken := &invitepb.InviteToken{ + Token: "token", + UserId: einstein.Id, + Expiration: &typesv1beta1.Timestamp{ + Seconds: 0, }, - }) - Expect(err).ToNot(HaveOccurred()) - variables = map[string]string{ - "invite_token_file": inviteTokenFile, + Description: "expired token", } - }) - It("will not complete the invitation workflow", func() { - users, code := findAccepted(tknEinstein, cernboxURL) - Expect(code).To(Equal(http.StatusOK)) - Expect(ocmUsersEqual(users, []*userpb.User{})).To(BeTrue()) + var cleanup func() + BeforeEach(func() { + variables, cleanup, err = initData(driver, []*invitepb.InviteToken{expiredToken}, nil) + Expect(err).ToNot(HaveOccurred()) + }) - code = acceptInvite(tknMarie, cesnetURL, "cernbox.cern.ch", expiredToken.Token) - Expect(code).To(Equal(http.StatusBadRequest)) + AfterEach(func() { + cleanup() + }) - users, code = findAccepted(tknEinstein, cernboxURL) - Expect(code).To(Equal(http.StatusOK)) - Expect(ocmUsersEqual(users, []*userpb.User{})).To(BeTrue()) - }) - }) + It("will not complete the invitation workflow", func() { + users, code := findAccepted(tknEinstein, cernboxURL) + Expect(code).To(Equal(http.StatusOK)) + Expect(ocmUsersEqual(users, []*userpb.User{})).To(BeTrue()) - Context("generate the token from http apis", func() { - BeforeEach(func() { - inviteTokenFile, err = helpers.TempJSONFile(map[string]map[string]*invitepb.InviteToken{}) - Expect(err).ToNot(HaveOccurred()) - variables = map[string]string{ - "invite_token_file": inviteTokenFile, - } + code = acceptInvite(tknMarie, cesnetURL, "cernbox.cern.ch", expiredToken.Token) + Expect(code).To(Equal(http.StatusBadRequest)) + + users, code = findAccepted(tknEinstein, cernboxURL) + Expect(code).To(Equal(http.StatusOK)) + Expect(ocmUsersEqual(users, []*userpb.User{})).To(BeTrue()) + }) }) - It("succeeds", func() { - users, code := findAccepted(tknEinstein, cernboxURL) - Expect(code).To(Equal(http.StatusOK)) - Expect(ocmUsersEqual(users, []*userpb.User{})).To(BeTrue()) - ocmToken, code := generateToken(tknEinstein, cernboxURL) - Expect(code).To(Equal(http.StatusOK)) + Context("generate the token from http apis", func() { + var cleanup func() + BeforeEach(func() { + variables, cleanup, err = initData(driver, nil, nil) + Expect(err).ToNot(HaveOccurred()) + }) - code = acceptInvite(tknMarie, cesnetURL, "cernbox.cern.ch", ocmToken.Token) - Expect(code).To(Equal(http.StatusOK)) + AfterEach(func() { + cleanup() + }) + + It("succeeds", func() { + users, code := findAccepted(tknEinstein, cernboxURL) + Expect(code).To(Equal(http.StatusOK)) + Expect(ocmUsersEqual(users, []*userpb.User{})).To(BeTrue()) + + ocmToken, code := generateToken(tknEinstein, cernboxURL) + Expect(code).To(Equal(http.StatusOK)) + + code = acceptInvite(tknMarie, cesnetURL, "cernbox.cern.ch", ocmToken.Token) + Expect(code).To(Equal(http.StatusOK)) - users, code = findAccepted(tknEinstein, cernboxURL) - Expect(code).To(Equal(http.StatusOK)) - Expect(ocmUsersEqual(users, []*userpb.User{marie})).To(BeTrue()) + users, code = findAccepted(tknEinstein, cernboxURL) + Expect(code).To(Equal(http.StatusOK)) + Expect(ocmUsersEqual(users, []*userpb.User{marie})).To(BeTrue()) + }) }) + }) - }) + } + }) func ocmUsersEqual(u1, u2 []*userpb.User) bool {