Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Session improvements #510

Merged
merged 9 commits into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ docker-compose up
| `BIND_ADDRESS` | The addresses that can access to the web interface and the port, use unix:///abspath/to/file.socket for unix domain socket. | 0.0.0.0:80 |
| `SESSION_SECRET` | The secret key used to encrypt the session cookies. Set this to a random value | N/A |
| `SESSION_SECRET_FILE` | Optional filepath for the secret key used to encrypt the session cookies. Leave `SESSION_SECRET` blank to take effect | N/A |
| `SESSION_MAX_DURATION` | Max time in days a remembered session is refreshed and valid. Non-refreshed session is valid for 7 days max, regardless of this setting. | 90 |
| `SUBNET_RANGES` | The list of address subdivision ranges. Format: `SR Name:10.0.1.0/24; SR2:10.0.2.0/24,10.0.3.0/24` Each CIDR must be inside one of the server interfaces. | N/A |
| `WGUI_USERNAME` | The username for the login page. Used for db initialization only | `admin` |
| `WGUI_PASSWORD` | The password for the user on the login page. Will be hashed automatically. Used for db initialization only | `admin` |
Expand Down
23 changes: 16 additions & 7 deletions handler/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,32 +93,41 @@ func Login(db store.IStore) echo.HandlerFunc {
}

if userCorrect && passwordCorrect {
// TODO: refresh the token
ageMax := 0
expiration := time.Now().Add(24 * time.Hour)
if rememberMe {
ageMax = 86400
expiration.Add(144 * time.Hour)
ageMax = 86400 * 7
}

cookiePath := util.GetCookiePath()

sess, _ := session.Get("session", c)
sess.Options = &sessions.Options{
Path: util.BasePath,
Path: cookiePath,
MaxAge: ageMax,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}

// set session_token
tokenUID := xid.New().String()
now := time.Now().UTC().Unix()
sess.Values["username"] = dbuser.Username
sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser)
sess.Values["admin"] = dbuser.Admin
sess.Values["session_token"] = tokenUID
sess.Values["max_age"] = ageMax
sess.Values["created_at"] = now
sess.Values["updated_at"] = now
sess.Save(c.Request(), c.Response())

// set session_token in cookie
cookie := new(http.Cookie)
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.Value = tokenUID
cookie.Expires = expiration
cookie.MaxAge = ageMax
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)

return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Logged in successfully"})
Expand Down Expand Up @@ -257,7 +266,7 @@ func UpdateUser(db store.IStore) echo.HandlerFunc {
log.Infof("Updated user information successfully")

if previousUsername == currentUser(c) {
setUser(c, user.Username, user.Admin)
setUser(c, user.Username, user.Admin, util.GetDBUserCRC32(user))
}

return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Updated user information successfully"})
Expand Down
168 changes: 167 additions & 1 deletion handler/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package handler
import (
"fmt"
"net/http"
"time"

"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/ngoduykhanh/wireguard-ui/util"
Expand All @@ -23,6 +25,15 @@ func ValidSession(next echo.HandlerFunc) echo.HandlerFunc {
}
}

// RefreshSession must only be used after ValidSession middleware
// RefreshSession checks if the session is eligible for the refresh, but doesn't check if it's fully valid
func RefreshSession(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
doRefreshSession(c)
return next(c)
}
}

func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if !isAdmin(c) {
Expand All @@ -41,9 +52,146 @@ func isValidSession(c echo.Context) bool {
if err != nil || sess.Values["session_token"] != cookie.Value {
return false
}

// Check time bounds
createdAt := getCreatedAt(sess)
updatedAt := getUpdatedAt(sess)
maxAge := getMaxAge(sess)
// Temporary session is considered valid within 24h if browser is not closed before
// This value is not saved and is used as virtual expiration
if maxAge == 0 {
maxAge = 86400
}
expiration := updatedAt + int64(maxAge)
now := time.Now().UTC().Unix()
if updatedAt > now || expiration < now || createdAt+util.SessionMaxDuration < now {
return false
}

// Check if user still exists and unchanged
username := fmt.Sprintf("%s", sess.Values["username"])
userHash := getUserHash(sess)
if uHash, ok := util.DBUsersToCRC32[username]; !ok || userHash != uHash {
return false
}

return true
}

// Refreshes a "remember me" session when the user visits web pages (not API)
// Session must be valid before calling this function
// Refresh is performed at most once per 24h
func doRefreshSession(c echo.Context) {
if util.DisableLogin {
return
}

sess, _ := session.Get("session", c)
maxAge := getMaxAge(sess)
if maxAge <= 0 {
return
}

oldCookie, err := c.Cookie("session_token")
if err != nil || sess.Values["session_token"] != oldCookie.Value {
return
}

// Refresh no sooner than 24h
createdAt := getCreatedAt(sess)
updatedAt := getUpdatedAt(sess)
expiration := updatedAt + int64(getMaxAge(sess))
now := time.Now().UTC().Unix()
if updatedAt > now || expiration < now || now-updatedAt < 86_400 || createdAt+util.SessionMaxDuration < now {
return
}

cookiePath := util.GetCookiePath()

sess.Values["updated_at"] = now
sess.Options = &sessions.Options{
Path: cookiePath,
MaxAge: maxAge,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
sess.Save(c.Request(), c.Response())

cookie := new(http.Cookie)
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.Value = oldCookie.Value
cookie.MaxAge = maxAge
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)
}

// Get time in seconds this session is valid without updating
func getMaxAge(sess *sessions.Session) int {
if util.DisableLogin {
return 0
}

maxAge := sess.Values["max_age"]

switch typedMaxAge := maxAge.(type) {
case int:
return typedMaxAge
default:
return 0
}
}

// Get a timestamp in seconds of the time the session was created
func getCreatedAt(sess *sessions.Session) int64 {
if util.DisableLogin {
return 0
}

createdAt := sess.Values["created_at"]

switch typedCreatedAt := createdAt.(type) {
case int64:
return typedCreatedAt
default:
return 0
}
}

// Get a timestamp in seconds of the last session update
func getUpdatedAt(sess *sessions.Session) int64 {
if util.DisableLogin {
return 0
}

lastUpdate := sess.Values["updated_at"]

switch typedLastUpdate := lastUpdate.(type) {
case int64:
return typedLastUpdate
default:
return 0
}
}

// Get CRC32 of a user at the moment of log in
// Any changes to user will result in logout of other (not updated) sessions
func getUserHash(sess *sessions.Session) uint32 {
if util.DisableLogin {
return 0
}

userHash := sess.Values["user_hash"]

switch typedUserHash := userHash.(type) {
case uint32:
return typedUserHash
default:
return 0
}
}

// currentUser to get username of logged in user
func currentUser(c echo.Context) string {
if util.DisableLogin {
Expand All @@ -66,9 +214,10 @@ func isAdmin(c echo.Context) bool {
return admin == "true"
}

func setUser(c echo.Context, username string, admin bool) {
func setUser(c echo.Context, username string, admin bool, userCRC32 uint32) {
sess, _ := session.Get("session", c)
sess.Values["username"] = username
sess.Values["user_hash"] = userCRC32
sess.Values["admin"] = admin
sess.Save(c.Request(), c.Response())
}
Expand All @@ -77,7 +226,24 @@ func setUser(c echo.Context, username string, admin bool) {
func clearSession(c echo.Context) {
sess, _ := session.Get("session", c)
sess.Values["username"] = ""
sess.Values["user_hash"] = 0
sess.Values["admin"] = false
sess.Values["session_token"] = ""
sess.Values["max_age"] = -1
sess.Options.MaxAge = -1
sess.Save(c.Request(), c.Response())

cookiePath := util.GetCookiePath()

cookie, err := c.Cookie("session_token")
if err != nil {
cookie = new(http.Cookie)
}

cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.MaxAge = -1
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)
}
20 changes: 12 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"crypto/sha512"
"embed"
"flag"
"fmt"
Expand Down Expand Up @@ -48,6 +49,7 @@ var (
flagTelegramAllowConfRequest = false
flagTelegramFloodWait = 60
flagSessionSecret = util.RandomString(32)
flagSessionMaxDuration = 90
flagWgConfTemplate string
flagBasePath string
flagSubnetRanges string
Expand Down Expand Up @@ -92,6 +94,7 @@ func init() {
flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.")
flag.StringVar(&flagBasePath, "base-path", util.LookupEnvOrString("BASE_PATH", flagBasePath), "The base path of the URL")
flag.StringVar(&flagSubnetRanges, "subnet-ranges", util.LookupEnvOrString("SUBNET_RANGES", flagSubnetRanges), "IP ranges to choose from when assigning an IP for a client.")
flag.IntVar(&flagSessionMaxDuration, "session-max-duration", util.LookupEnvOrInt("SESSION_MAX_DURATION", flagSessionMaxDuration), "Max time in days a remembered session is refreshed and valid.")

var (
smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword)
Expand Down Expand Up @@ -136,7 +139,8 @@ func init() {
util.SendgridApiKey = flagSendgridApiKey
util.EmailFrom = flagEmailFrom
util.EmailFromName = flagEmailFromName
util.SessionSecret = []byte(flagSessionSecret)
util.SessionSecret = sha512.Sum512([]byte(flagSessionSecret))
util.SessionMaxDuration = int64(flagSessionMaxDuration) * 86_400 // Store in seconds
util.WgConfTemplate = flagWgConfTemplate
util.BasePath = util.ParseBasePath(flagBasePath)
util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges)
Expand Down Expand Up @@ -205,7 +209,7 @@ func main() {
// register routes
app := router.New(tmplDir, extraData, util.SessionSecret)

app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession)
app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession, handler.RefreshSession)

// Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to
// mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on
Expand All @@ -215,8 +219,8 @@ func main() {
app.GET(util.BasePath+"/login", handler.LoginPage())
app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson)
app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession)
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession)
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession)
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
Expand All @@ -242,19 +246,19 @@ func main() {
app.POST(util.BasePath+"/client/set-status", handler.SetClientStatus(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/remove-client", handler.RemoveClient(db), handler.ValidSession, handler.ContentTypeJson)
app.GET(util.BasePath+"/download", handler.DownloadClient(db), handler.ValidSession)
app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/wg-server/interfaces", handler.WireGuardServerInterfaces(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/wg-server/keypair", handler.WireGuardServerKeyPair(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/global-settings", handler.GlobalSettingSubmit(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession)
app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession, handler.RefreshSession)
app.GET(util.BasePath+"/api/clients", handler.GetClients(db), handler.ValidSession)
app.GET(util.BasePath+"/api/client/:id", handler.GetClient(db), handler.ValidSession)
app.GET(util.BasePath+"/api/machine-ips", handler.MachineIPAddresses(), handler.ValidSession)
app.GET(util.BasePath+"/api/subnet-ranges", handler.GetOrderedSubnetRanges(), handler.ValidSession)
app.GET(util.BasePath+"/api/suggest-client-ips", handler.SuggestIPAllocation(db), handler.ValidSession)
app.POST(util.BasePath+"/api/apply-wg-config", handler.ApplyServerConfig(db, tmplDir), handler.ValidSession, handler.ContentTypeJson)
app.GET(util.BasePath+"/wake_on_lan_hosts", handler.GetWakeOnLanHosts(db), handler.ValidSession)
app.GET(util.BasePath+"/wake_on_lan_hosts", handler.GetWakeOnLanHosts(db), handler.ValidSession, handler.RefreshSession)
app.POST(util.BasePath+"/wake_on_lan_host", handler.SaveWakeOnLanHost(db), handler.ValidSession, handler.ContentTypeJson)
app.DELETE(util.BasePath+"/wake_on_lan_host/:mac_address", handler.DeleteWakeOnHost(db), handler.ValidSession, handler.ContentTypeJson)
app.PUT(util.BasePath+"/wake_on_lan_host/:mac_address", handler.WakeOnHost(db), handler.ValidSession, handler.ContentTypeJson)
Expand Down
12 changes: 10 additions & 2 deletions router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,17 @@ func (t *TemplateRegistry) Render(w io.Writer, name string, data interface{}, c
}

// New function
func New(tmplDir fs.FS, extraData map[string]interface{}, secret []byte) *echo.Echo {
func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo.Echo {
e := echo.New()
e.Use(session.Middleware(sessions.NewCookieStore(secret)))

cookiePath := util.GetCookiePath()

cookieStore := sessions.NewCookieStore(secret[:32], secret[32:])
cookieStore.Options.Path = cookiePath
cookieStore.Options.HttpOnly = true
cookieStore.MaxAge(86400 * 7)

e.Use(session.Middleware(cookieStore))

// read html template file to string
tmplBaseString, err := util.StringFromEmbedFile(tmplDir, "base.html")
Expand Down
Loading