diff --git a/main.go b/main.go index 3e3a005dd9..d3ed52a372 100644 --- a/main.go +++ b/main.go @@ -141,6 +141,7 @@ func main() { sessionRegistry := server.NewLocalSessionRegistry(metrics) sessionCache := server.NewLocalSessionCache(config.GetSession().TokenExpirySec) consoleSessionCache := server.NewLocalSessionCache(config.GetConsole().TokenExpirySec) + loginAttemptCache := server.NewLocalLoginAttemptCache() statusRegistry := server.NewStatusRegistry(logger, config, sessionRegistry, jsonpbMarshaler) tracker := server.StartLocalTracker(logger, config, sessionRegistry, statusRegistry, metrics, jsonpbMarshaler) router := server.NewLocalMessageRouter(sessionRegistry, tracker, jsonpbMarshaler) @@ -166,7 +167,7 @@ func main() { statusHandler := server.NewLocalStatusHandler(logger, sessionRegistry, matchRegistry, tracker, metrics, config.GetName()) apiServer := server.StartApiServer(logger, startupLogger, db, jsonpbMarshaler, jsonpbUnmarshaler, config, socialClient, leaderboardCache, leaderboardRankCache, sessionRegistry, sessionCache, statusRegistry, matchRegistry, matchmaker, tracker, router, streamManager, metrics, pipeline, runtime) - consoleServer := server.StartConsoleServer(logger, startupLogger, db, config, tracker, router, streamManager, sessionCache, consoleSessionCache, statusRegistry, statusHandler, runtimeInfo, matchRegistry, configWarnings, semver, leaderboardCache, leaderboardRankCache, apiServer, cookie) + consoleServer := server.StartConsoleServer(logger, startupLogger, db, config, tracker, router, streamManager, sessionCache, consoleSessionCache, loginAttemptCache, statusRegistry, statusHandler, runtimeInfo, matchRegistry, configWarnings, semver, leaderboardCache, leaderboardRankCache, apiServer, cookie) gaenabled := len(os.Getenv("NAKAMA_TELEMETRY")) < 1 const gacode = "UA-89792135-1" @@ -232,6 +233,7 @@ func main() { sessionCache.Stop() sessionRegistry.Stop() metrics.Stop(logger) + loginAttemptCache.Stop() if gaenabled { _ = ga.SendSessionStop(telemetryClient, gacode, cookie) diff --git a/server/console.go b/server/console.go index 56b26eb062..215fe5d82c 100644 --- a/server/console.go +++ b/server/console.go @@ -138,6 +138,7 @@ type ConsoleServer struct { StreamManager StreamManager sessionCache SessionCache consoleSessionCache SessionCache + loginAttemptCache LoginAttemptCache statusRegistry *StatusRegistry matchRegistry MatchRegistry statusHandler StatusHandler @@ -155,7 +156,7 @@ type ConsoleServer struct { httpClient *http.Client } -func StartConsoleServer(logger *zap.Logger, startupLogger *zap.Logger, db *sql.DB, config Config, tracker Tracker, router MessageRouter, streamManager StreamManager, sessionCache SessionCache, consoleSessionCache SessionCache, statusRegistry *StatusRegistry, statusHandler StatusHandler, runtimeInfo *RuntimeInfo, matchRegistry MatchRegistry, configWarnings map[string]string, serverVersion string, leaderboardCache LeaderboardCache, leaderboardRankCache LeaderboardRankCache, api *ApiServer, cookie string) *ConsoleServer { +func StartConsoleServer(logger *zap.Logger, startupLogger *zap.Logger, db *sql.DB, config Config, tracker Tracker, router MessageRouter, streamManager StreamManager, sessionCache SessionCache, consoleSessionCache SessionCache, loginAttemptCache LoginAttemptCache, statusRegistry *StatusRegistry, statusHandler StatusHandler, runtimeInfo *RuntimeInfo, matchRegistry MatchRegistry, configWarnings map[string]string, serverVersion string, leaderboardCache LeaderboardCache, leaderboardRankCache LeaderboardRankCache, api *ApiServer, cookie string) *ConsoleServer { var gatewayContextTimeoutMs string if config.GetConsole().IdleTimeoutMs > 500 { // Ensure the GRPC Gateway timeout is just under the idle timeout (if possible) to ensure it has priority. @@ -182,6 +183,7 @@ func StartConsoleServer(logger *zap.Logger, startupLogger *zap.Logger, db *sql.D StreamManager: streamManager, sessionCache: sessionCache, consoleSessionCache: consoleSessionCache, + loginAttemptCache: loginAttemptCache, statusRegistry: statusRegistry, matchRegistry: matchRegistry, statusHandler: statusHandler, diff --git a/server/console_authenticate.go b/server/console_authenticate.go index b967e678e0..5ee3157380 100644 --- a/server/console_authenticate.go +++ b/server/console_authenticate.go @@ -20,10 +20,9 @@ import ( "database/sql" "errors" "fmt" - "github.com/gofrs/uuid" - "google.golang.org/protobuf/types/known/emptypb" "time" + "github.com/gofrs/uuid" jwt "github.com/golang-jwt/jwt/v4" "github.com/heroiclabs/nakama/v3/console" "github.com/jackc/pgtype" @@ -31,6 +30,7 @@ import ( "golang.org/x/crypto/bcrypt" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" ) type ConsoleTokenClaims struct { @@ -71,6 +71,11 @@ func parseConsoleToken(hmacSecretByte []byte, tokenString string) (id, username, } func (s *ConsoleServer) Authenticate(ctx context.Context, in *console.AuthenticateRequest) (*console.ConsoleSession, error) { + ip, _ := extractClientAddressFromContext(s.logger, ctx) + if !s.loginAttemptCache.Allow(in.Username, ip) { + return nil, status.Error(codes.ResourceExhausted, "Try again later.") + } + role := console.UserRole_USER_ROLE_UNKNOWN var uname string var email string @@ -81,10 +86,20 @@ func (s *ConsoleServer) Authenticate(ctx context.Context, in *console.Authentica role = console.UserRole_USER_ROLE_ADMIN uname = in.Username id = uuid.Nil + } else { + if lockout, until := s.loginAttemptCache.Add(s.config.GetConsole().Username, ip); lockout != LockoutTypeNone { + switch lockout { + case LockoutTypeAccount: + s.logger.Info(fmt.Sprintf("Console admin account locked until %v.", until)) + case LockoutTypeIp: + s.logger.Info(fmt.Sprintf("Console admin IP locked until %v.", until)) + } + } + return nil, status.Error(codes.Unauthenticated, "Invalid credentials.") } default: var err error - id, uname, email, role, err = s.lookupConsoleUser(ctx, in.Username, in.Password) + id, uname, email, role, err = s.lookupConsoleUser(ctx, in.Username, in.Password, ip) if err != nil { return nil, err } @@ -94,7 +109,10 @@ func (s *ConsoleServer) Authenticate(ctx context.Context, in *console.Authentica return nil, status.Error(codes.Unauthenticated, "Invalid credentials.") } + s.loginAttemptCache.Reset(uname) + exp := time.Now().UTC().Add(time.Duration(s.config.GetConsole().TokenExpirySec) * time.Second).Unix() + token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ConsoleTokenClaims{ ExpiresAt: exp, ID: id.String(), @@ -132,7 +150,7 @@ func (s *ConsoleServer) AuthenticateLogout(ctx context.Context, in *console.Auth return &emptypb.Empty{}, nil } -func (s *ConsoleServer) lookupConsoleUser(ctx context.Context, unameOrEmail, password string) (id uuid.UUID, uname string, email string, role console.UserRole, err error) { +func (s *ConsoleServer) lookupConsoleUser(ctx context.Context, unameOrEmail, password, ip string) (id uuid.UUID, uname string, email string, role console.UserRole, err error) { role = console.UserRole_USER_ROLE_UNKNOWN query := "SELECT id, username, email, role, password, disable_time FROM console_user WHERE username = $1 OR email = $1" var dbPassword []byte @@ -140,11 +158,20 @@ func (s *ConsoleServer) lookupConsoleUser(ctx context.Context, unameOrEmail, pas err = s.db.QueryRowContext(ctx, query, unameOrEmail).Scan(&id, &uname, &email, &role, &dbPassword, &dbDisableTime) if err != nil { if err == sql.ErrNoRows { - err = nil + if lockout, until := s.loginAttemptCache.Add("", ip); lockout == LockoutTypeIp { + s.logger.Info(fmt.Sprintf("Console user IP locked until %v.", until)) + } + err = status.Error(codes.Unauthenticated, "Invalid credentials.") } return } + // Check lockout again as the login attempt may have been through email. + if !s.loginAttemptCache.Allow(uname, ip) { + err = status.Error(codes.ResourceExhausted, "Try again later.") + return + } + // Check if it's disabled. if dbDisableTime.Status == pgtype.Present && dbDisableTime.Time.Unix() != 0 { s.logger.Info("Console user account is disabled.", zap.String("username", unameOrEmail)) @@ -155,6 +182,14 @@ func (s *ConsoleServer) lookupConsoleUser(ctx context.Context, unameOrEmail, pas // Check password err = bcrypt.CompareHashAndPassword(dbPassword, []byte(password)) if err != nil { + if lockout, until := s.loginAttemptCache.Add(uname, ip); lockout != LockoutTypeNone { + switch lockout { + case LockoutTypeAccount: + s.logger.Info(fmt.Sprintf("Console user account locked until %v.", until)) + case LockoutTypeIp: + s.logger.Info(fmt.Sprintf("Console user IP locked until %v.", until)) + } + } err = status.Error(codes.Unauthenticated, "Invalid credentials.") return } diff --git a/server/login_attempt_cache.go b/server/login_attempt_cache.go new file mode 100644 index 0000000000..3bb8821cec --- /dev/null +++ b/server/login_attempt_cache.go @@ -0,0 +1,174 @@ +// Copyright 2022 The Nakama Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "sync" + "time" +) + +type LockoutType uint8 + +const ( + LockoutTypeNone LockoutType = iota + LockoutTypeAccount + LockoutTypeIp +) + +const ( + maxAttemptsAccount = 5 + lockoutPeriodAccount = time.Minute * 1 + + maxAttemptsIp = 10 + lockoutPeriodIp = time.Minute * 10 +) + +type LoginAttemptCache interface { + Stop() + // Allow checks whether account or IP is locked out or should be allowed to attempt to authenticate. + Allow(account, ip string) bool + // Add a failed attempt and return current lockout status. + Add(account, ip string) (LockoutType, time.Time) + // Reset account attempts on successful login. + Reset(account string) +} + +type lockoutStatus struct { + lockedUntil time.Time + attempts []time.Time +} + +func (ls *lockoutStatus) trim(now time.Time, retentionPeriod time.Duration) bool { + if ls.lockedUntil.Before(now) { + ls.lockedUntil = time.Time{} + } + for i := len(ls.attempts) - 1; i >= 0; i-- { + if now.Sub(ls.attempts[i]) >= retentionPeriod { + ls.attempts = ls.attempts[i+1:] + break + } + } + + return ls.lockedUntil.IsZero() && len(ls.attempts) == 0 +} + +type LocalLoginAttemptCache struct { + sync.RWMutex + ctx context.Context + ctxCancelFn context.CancelFunc + + accountCache map[string]*lockoutStatus + ipCache map[string]*lockoutStatus +} + +func NewLocalLoginAttemptCache() LoginAttemptCache { + ctx, ctxCancelFn := context.WithCancel(context.Background()) + + c := &LocalLoginAttemptCache{ + accountCache: make(map[string]*lockoutStatus), + ipCache: make(map[string]*lockoutStatus), + + ctx: ctx, + ctxCancelFn: ctxCancelFn, + } + + go func() { + ticker := time.NewTicker(10 * time.Minute) + for { + select { + case <-c.ctx.Done(): + ticker.Stop() + return + case t := <-ticker.C: + now := t.UTC() + c.Lock() + for account, status := range c.accountCache { + if status.trim(now, lockoutPeriodAccount) { + delete(c.accountCache, account) + } + } + for ip, status := range c.ipCache { + if status.trim(now, lockoutPeriodIp) { + delete(c.ipCache, ip) + } + } + c.Unlock() + } + } + }() + + return c +} + +func (c *LocalLoginAttemptCache) Stop() { + c.ctxCancelFn() +} + +func (c *LocalLoginAttemptCache) Allow(account, ip string) bool { + now := time.Now().UTC() + c.RLock() + defer c.RUnlock() + if status, found := c.accountCache[account]; found && !status.lockedUntil.IsZero() && status.lockedUntil.After(now) { + return false + } + if status, found := c.ipCache[ip]; found && !status.lockedUntil.IsZero() && status.lockedUntil.After(now) { + return false + } + return true +} + +func (c *LocalLoginAttemptCache) Reset(account string) { + c.Lock() + delete(c.accountCache, account) + c.Unlock() +} + +func (c *LocalLoginAttemptCache) Add(account, ip string) (LockoutType, time.Time) { + now := time.Now().UTC() + var lockoutType LockoutType + var lockedUntil time.Time + c.Lock() + defer c.Unlock() + if account != "" { + status, found := c.accountCache[account] + if !found { + status = &lockoutStatus{} + c.accountCache[account] = status + } + status.attempts = append(status.attempts, now) + _ = status.trim(now, lockoutPeriodAccount) + if len(status.attempts) >= maxAttemptsAccount { + status.lockedUntil = now.Add(lockoutPeriodAccount) + lockedUntil = status.lockedUntil + lockoutType = LockoutTypeAccount + } + } + if ip != "" { + status, found := c.ipCache[ip] + if !found { + status = &lockoutStatus{} + c.ipCache[ip] = status + } + status.attempts = append(status.attempts, now) + _ = status.trim(now, lockoutPeriodIp) + if len(status.attempts) >= maxAttemptsIp { + status.lockedUntil = now.Add(lockoutPeriodIp) + lockedUntil = status.lockedUntil + lockoutType = LockoutTypeIp + } + } + return lockoutType, lockedUntil +}