diff --git a/occupi-backend/.dev.env.gpg b/occupi-backend/.dev.env.gpg index 0c4d9bb3..2ceb8bc6 100644 Binary files a/occupi-backend/.dev.env.gpg and b/occupi-backend/.dev.env.gpg differ diff --git a/occupi-backend/.env.gpg b/occupi-backend/.env.gpg index a81501fe..187f8513 100644 Binary files a/occupi-backend/.env.gpg and b/occupi-backend/.env.gpg differ diff --git a/occupi-backend/.prod.env.gpg b/occupi-backend/.prod.env.gpg index 777eb2b2..fdcdf599 100644 Binary files a/occupi-backend/.prod.env.gpg and b/occupi-backend/.prod.env.gpg differ diff --git a/occupi-backend/configs/config.go b/occupi-backend/configs/config.go index 18ec3a50..b9392cf6 100644 --- a/occupi-backend/configs/config.go +++ b/occupi-backend/configs/config.go @@ -165,3 +165,12 @@ func GetSessionSecret() string { } return secret } + +func GetOccupiDomains() []string { + domains := os.Getenv("OCCUPI_DOMAINS") + if domains != "" { + domainList := strings.Split(domains, ",") + return domainList + } + return []string{""} +} diff --git a/occupi-backend/pkg/authenticator/auth.go b/occupi-backend/pkg/authenticator/auth.go index fba96622..03f07db4 100644 --- a/occupi-backend/pkg/authenticator/auth.go +++ b/occupi-backend/pkg/authenticator/auth.go @@ -17,8 +17,15 @@ type Claims struct { } // GenerateToken generates a JWT token for the user -func GenerateToken(email string, role string) (string, time.Time, error) { - expirationTime := time.Now().Add(5 * time.Minute) +func GenerateToken(email string, role string, optionalExpiryTime ...time.Duration) (string, time.Time, error) { + var expirationTime time.Time + + if len(optionalExpiryTime) == 0 { + expirationTime = time.Now().Add(24 * 7 * time.Hour) + } else { + expirationTime = time.Now().Add(optionalExpiryTime[0]) + } + claims := &Claims{ Email: email, Role: role, diff --git a/occupi-backend/pkg/handlers/auth_handlers.go b/occupi-backend/pkg/handlers/auth_handlers.go index 004cf145..766785d8 100644 --- a/occupi-backend/pkg/handlers/auth_handlers.go +++ b/occupi-backend/pkg/handlers/auth_handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "time" + "github.com/COS301-SE-2024/occupi/occupi-backend/configs" "github.com/COS301-SE-2024/occupi/occupi-backend/pkg/authenticator" "github.com/COS301-SE-2024/occupi/occupi-backend/pkg/constants" "github.com/COS301-SE-2024/occupi/occupi-backend/pkg/database" @@ -149,7 +150,7 @@ func Login(ctx *gin.Context, appsession *models.AppSession, role string) { if role == constants.Admin { token, expirationTime, err = authenticator.GenerateToken(requestUser.Email, constants.Admin) } else { - token, expirationTime, err = authenticator.GenerateToken(requestUser.Email, "user") + token, expirationTime, err = authenticator.GenerateToken(requestUser.Email, constants.Basic) } if err != nil { @@ -389,7 +390,7 @@ func ResetPassword(ctx *gin.Context, appsession *models.AppSession) { // this will contain reset password logic } -// handler for logging out a user on occupi /auth/logout TODO: complete implementation +// handler for logging out a user func Logout(ctx *gin.Context) { session := sessions.Default(ctx) session.Clear() @@ -399,7 +400,15 @@ func Logout(ctx *gin.Context) { return } - ctx.SetCookie("token", "", -1, "/", "localhost", false, true) + // List of domains to clear cookies from + domains := configs.GetOccupiDomains() + + // Iterate over each domain and clear the "token" and "occupi-sessions-store" cookies + for _, domain := range domains { + ctx.SetCookie("token", "", -1, "/", domain, false, true) + ctx.SetCookie("occupi-sessions-store", "", -1, "/", domain, false, true) + } + ctx.JSON(http.StatusOK, utils.SuccessResponse( http.StatusOK, "Logged out successfully!", diff --git a/occupi-backend/pkg/middleware/middleware.go b/occupi-backend/pkg/middleware/middleware.go index b1c0f905..f0f4105e 100644 --- a/occupi-backend/pkg/middleware/middleware.go +++ b/occupi-backend/pkg/middleware/middleware.go @@ -39,21 +39,38 @@ func ProtectedRoute(ctx *gin.Context) { http.StatusUnauthorized, "Bad Request", constants.InvalidAuthCode, - "User not authorized", + "Invalid token", nil)) ctx.Abort() return } + // check if email and role session variables are set session := sessions.Default(ctx) - session.Set("email", claims.Email) - session.Set("role", claims.Role) - if err := session.Save(); err != nil { - ctx.JSON(http.StatusInternalServerError, utils.InternalServerError()) - logrus.Error(err) + if session.Get("email") == nil || session.Get("role") == nil { + session.Set("email", claims.Email) + session.Set("role", claims.Role) + if err := session.Save(); err != nil { + ctx.JSON(http.StatusInternalServerError, utils.InternalServerError()) + logrus.Error(err) + ctx.Abort() + return + } + } + + // check that session variables and token claims match + if session.Get("email") != claims.Email || session.Get("role") != claims.Role { + ctx.JSON(http.StatusUnauthorized, + utils.ErrorResponse( + http.StatusUnauthorized, + "Bad Request", + constants.InvalidAuthCode, + "Inalid auth session", + nil)) ctx.Abort() return } + ctx.Next() } @@ -77,6 +94,19 @@ func UnProtectedRoute(ctx *gin.Context) { } } + // check if email and role session variables are set + session := sessions.Default(ctx) + if session.Get("email") != nil || session.Get("role") != nil { + session.Delete("email") + session.Delete("role") + if err := session.Save(); err != nil { + ctx.JSON(http.StatusInternalServerError, utils.InternalServerError()) + logrus.Error(err) + ctx.Abort() + return + } + } + ctx.Next() } diff --git a/occupi-backend/tests/authenticator_test.go b/occupi-backend/tests/authenticator_test.go index 7a706377..12a5ee16 100644 --- a/occupi-backend/tests/authenticator_test.go +++ b/occupi-backend/tests/authenticator_test.go @@ -18,13 +18,13 @@ func TestGenerateToken(t *testing.T) { t.Fatal("Error loading .env file: ", err) } - email := "test@example.com" + email := "test1@example.com" role := constants.Admin tokenString, expirationTime, err := authenticator.GenerateToken(email, role) require.NoError(t, err) require.NotEmpty(t, tokenString) - require.WithinDuration(t, time.Now().Add(5*time.Minute), expirationTime, time.Second) + require.WithinDuration(t, time.Now().Add(24*7*time.Hour), expirationTime, time.Second) // Validate the token claims, err := authenticator.ValidateToken(tokenString) @@ -40,7 +40,7 @@ func TestValidateToken(t *testing.T) { t.Fatal("Error loading .env file: ", err) } - email := "test@example.com" + email := "test2@example.com" role := constants.Admin tokenString, _, err := authenticator.GenerateToken(email, role) @@ -60,3 +60,26 @@ func TestValidateToken(t *testing.T) { require.Error(t, err) assert.Nil(t, claims) } + +func TestValidateTokenExpired(t *testing.T) { + // Load environment variables from .env file + if err := godotenv.Load("../.env"); err != nil { + t.Fatal("Error loading .env file: ", err) + } + + email := "test3@example.com" + role := constants.Admin + + // Generate a token that expires in 1 second + tokenString, _, err := authenticator.GenerateToken(email, role, 1*time.Second) + require.NoError(t, err) + require.NotEmpty(t, tokenString) + + // Wait for the token to expire + time.Sleep(2 * time.Second) + + // Validate the token + claims, err := authenticator.ValidateToken(tokenString) + require.Error(t, err) + assert.Nil(t, claims) +} diff --git a/occupi-backend/tests/handlers_test.go b/occupi-backend/tests/handlers_test.go index 0cb9bf8b..5fce01f9 100644 --- a/occupi-backend/tests/handlers_test.go +++ b/occupi-backend/tests/handlers_test.go @@ -14,6 +14,9 @@ import ( "github.com/gin-gonic/gin" + "github.com/COS301-SE-2024/occupi/occupi-backend/configs" + "github.com/COS301-SE-2024/occupi/occupi-backend/pkg/authenticator" + "github.com/COS301-SE-2024/occupi/occupi-backend/pkg/constants" "github.com/COS301-SE-2024/occupi/occupi-backend/pkg/database" "github.com/COS301-SE-2024/occupi/occupi-backend/pkg/middleware" "github.com/COS301-SE-2024/occupi/occupi-backend/pkg/router" @@ -161,9 +164,6 @@ func TestPingRoute(t *testing.T) { // Create a Gin router ginRouter := gin.Default() - // adding rate limiting middleware - middleware.AttachRateLimitMiddleware(ginRouter) - // Register routes router.OccupiRouter(ginRouter, db) @@ -361,3 +361,199 @@ func TestRateLimitWithMultipleIPs(t *testing.T) { // Assertions for IP2 assert.Equal(t, rateLimitedCountIP2, 0, "There should be no requests from IP2 that are rate limited") } + +func TestInvalidLogoutHandler(t *testing.T) { + // Load environment variables from .env file + if err := godotenv.Load("../.env"); err != nil { + t.Fatal("Error loading .env file: ", err) + } + + // setup logger to log all server interactions + utils.SetupLogger() + + // connect to the database + db := database.ConnectToDatabase() + + // set gin run mode + gin.SetMode("test") + + // Create a Gin router + ginRouter := gin.Default() + + // Register routes + router.OccupiRouter(ginRouter, db) + + // Create a request to pass to the handler + req, err := http.NewRequest("POST", "/auth/logout", nil) + if err != nil { + t.Fatal("Error creating request: ", err) + } + + // Record the HTTP response + rr := httptest.NewRecorder() + + // Serve the request + ginRouter.ServeHTTP(rr, req) + + // Check the status code is what we expect + if status := rr.Code; status != http.StatusUnauthorized { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusUnauthorized) + } +} + +func TestValidLogoutHandler(t *testing.T) { + // Load environment variables from .env file + if err := godotenv.Load("../.env"); err != nil { + t.Fatal("Error loading .env file: ", err) + } + + // setup logger to log all server interactions + utils.SetupLogger() + + // connect to the database + db := database.ConnectToDatabase() + + // set gin run mode + gin.SetMode("test") + + // Create a Gin router + ginRouter := gin.Default() + + // Register routes + router.OccupiRouter(ginRouter, db) + + // Create a request to pass to the handler + req, err := http.NewRequest("POST", "/auth/logout", nil) + if err != nil { + t.Fatal("Error creating request: ", err) + } + + // Set up cookies for the request, "token" and "occupi-sessions-store" + token, _, err := authenticator.GenerateToken("example@gmail.com", constants.Basic) + if err != nil { + t.Fatal("Error generating token: ", err) + } + cookie1 := http.Cookie{ + Name: "token", + Value: token, + } + req.AddCookie(&cookie1) + + // Record the HTTP response + rr := httptest.NewRecorder() + + // Serve the request + ginRouter.ServeHTTP(rr, req) + + // Check the status code is what we expect + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) + } + + // ensure that protected route cannot be accessed like ping-auth + req, err = http.NewRequest("GET", "/ping-auth", nil) + + if err != nil { + t.Fatal("Error creating request: ", err) + } + + // record the HTTP response + rr = httptest.NewRecorder() + + // serve the request + ginRouter.ServeHTTP(rr, req) + + // check the status code is what we expect + if status := rr.Code; status != http.StatusUnauthorized { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusUnauthorized) + } +} + +func TestValidLogoutHandlerFromDomains(t *testing.T) { + // Load environment variables from .env file + if err := godotenv.Load("../.env"); err != nil { + t.Fatal("Error loading .env file: ", err) + } + + // setup logger to log all server interactions + utils.SetupLogger() + + // connect to the database + db := database.ConnectToDatabase() + + // set gin run mode + gin.SetMode("test") + + // Create a Gin router + ginRouter := gin.Default() + + // Register routes + router.OccupiRouter(ginRouter, db) + + // read domains + domains := configs.GetOccupiDomains() + + // use a wait group to handle concurrency + var wg sync.WaitGroup + + for _, domain := range domains { + wg.Add(1) + + go func(domain string) { + defer wg.Done() + + // Create a request to pass to the handler + req, err := http.NewRequest("POST", "/auth/logout", nil) + if err != nil { + t.Errorf("Error creating request: %v", err) + return + } + + // set the domain + req.Host = domain + + // Set up cookies for the request, "token" and "occupi-sessions-store" + token, _, err := authenticator.GenerateToken("example@gmail.com", constants.Basic) + if err != nil { + t.Errorf("Error generating token: %s", err) + } + cookie1 := http.Cookie{ + Name: "token", + Value: token, + } + req.AddCookie(&cookie1) + + // Record the HTTP response + rr := httptest.NewRecorder() + + // Serve the request + ginRouter.ServeHTTP(rr, req) + + // Check the status code is what we expect + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code for domain %s: got %v want %v", domain, status, http.StatusOK) + } + + // ensure that protected route cannot be accessed like ping-auth + req, err = http.NewRequest("GET", "/ping-auth", nil) + + if err != nil { + t.Errorf("Error creating request: %s", err) + } + + // record the HTTP response + rr = httptest.NewRecorder() + + // serve the request + ginRouter.ServeHTTP(rr, req) + + // check the status code is what we expect + if status := rr.Code; status != http.StatusUnauthorized { + t.Errorf("handler returned wrong status code: got %v want %v for domain: %s", status, http.StatusUnauthorized, domain) + } + }(domain) + } + + // Wait for all goroutines to finish + wg.Wait() +}