diff --git a/internal/config/config.go b/internal/config/config.go index 73106fbde..3c10a9472 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -30,6 +30,8 @@ type Config struct { NavbarTitle string Env map[string]string TLS *TLS + IsAuthToken bool + AuthToken string } type TLS struct { @@ -71,7 +73,8 @@ func LoadConfig(userHomeDir string) error { _ = viper.BindEnv("navbarTitle", "DAGU_NAVBAR_TITLE") _ = viper.BindEnv("tls.certFile", "DAGU_CERT_FILE") _ = viper.BindEnv("tls.keyFile", "DAGU_KEY_FILE") - + _ = viper.BindEnv("isAuthToken", "DAGU_IS_AUTHTOKEN") + _ = viper.BindEnv("authToken", "DAGU_AUTHTOKEN") command := "dagu" if ex, err := os.Executable(); err == nil { command = ex @@ -93,6 +96,8 @@ func LoadConfig(userHomeDir string) error { viper.SetDefault("adminLogsDir", path.Join(appHome, "logs", "admin")) viper.SetDefault("navbarColor", "") viper.SetDefault("navbarTitle", "Dagu") + viper.SetDefault("isAuthToken", "0") + viper.SetDefault("authToken", "") viper.AutomaticEnv() diff --git a/output.diff b/output.diff new file mode 100644 index 000000000..9fcacebbd --- /dev/null +++ b/output.diff @@ -0,0 +1,138 @@ +diff --git a/internal/config/config.go b/internal/config/config.go +index 73106fb..3c10a94 100644 +--- a/internal/config/config.go ++++ b/internal/config/config.go +@@ -30,6 +30,8 @@ type Config struct { + NavbarTitle string + Env map[string]string + TLS *TLS ++ IsAuthToken bool ++ AuthToken string + } + + type TLS struct { +@@ -71,7 +73,8 @@ func LoadConfig(userHomeDir string) error { + _ = viper.BindEnv("navbarTitle", "DAGU_NAVBAR_TITLE") + _ = viper.BindEnv("tls.certFile", "DAGU_CERT_FILE") + _ = viper.BindEnv("tls.keyFile", "DAGU_KEY_FILE") +- ++ _ = viper.BindEnv("isAuthToken", "DAGU_IS_AUTHTOKEN") ++ _ = viper.BindEnv("authToken", "DAGU_AUTHTOKEN") + command := "dagu" + if ex, err := os.Executable(); err == nil { + command = ex +@@ -93,6 +96,8 @@ func LoadConfig(userHomeDir string) error { + viper.SetDefault("adminLogsDir", path.Join(appHome, "logs", "admin")) + viper.SetDefault("navbarColor", "") + viper.SetDefault("navbarTitle", "Dagu") ++ viper.SetDefault("isAuthToken", "0") ++ viper.SetDefault("authToken", "") + + viper.AutomaticEnv() + +diff --git a/service/frontend/fx.go b/service/frontend/fx.go +index 8be9f69..fce5449 100644 +--- a/service/frontend/fx.go ++++ b/service/frontend/fx.go +@@ -3,6 +3,7 @@ package frontend + import ( + "context" + "embed" ++ + "github.com/dagu-dev/dagu/internal/config" + "github.com/dagu-dev/dagu/internal/logger" + "github.com/dagu-dev/dagu/service/frontend/handlers" +@@ -53,6 +54,13 @@ func New(params Params) *server.Server { + AssetsFS: assetsFS, + } + ++ if params.Config.IsAuthToken { ++ ++ serverParams.AuthToken = &server.AuthToken{ ++ Token: params.Config.AuthToken, ++ } ++ } ++ + if params.Config.IsBasicAuth { + serverParams.BasicAuth = &server.BasicAuth{ + Username: params.Config.BasicAuthUsername, +diff --git a/service/frontend/server/server.go b/service/frontend/server/server.go +index eca1b3c..d6eb8bc 100644 +--- a/service/frontend/server/server.go ++++ b/service/frontend/server/server.go +@@ -3,17 +3,18 @@ package server + import ( + "context" + "errors" ++ "io/fs" ++ "net/http" ++ "os" ++ "os/signal" ++ "syscall" ++ + "github.com/dagu-dev/dagu/internal/config" + "github.com/dagu-dev/dagu/internal/logger" + "github.com/dagu-dev/dagu/internal/logger/tag" + "github.com/dagu-dev/dagu/service/frontend/restapi" + "github.com/go-openapi/loads" + flags "github.com/jessevdk/go-flags" +- "io/fs" +- "net/http" +- "os" +- "os/signal" +- "syscall" + + pkgmiddleware "github.com/dagu-dev/dagu/service/frontend/middleware" + "github.com/dagu-dev/dagu/service/frontend/restapi/operations" +@@ -26,10 +27,15 @@ type BasicAuth struct { + Password string + } + ++type AuthToken struct { ++ Token string ++} ++ + type Params struct { + Host string + Port int + BasicAuth *BasicAuth ++ AuthToken *AuthToken + TLS *config.TLS + Logger logger.Logger + Handlers []New +@@ -40,6 +46,7 @@ type Server struct { + host string + port int + basicAuth *BasicAuth ++ authToken *AuthToken + tls *config.TLS + logger logger.Logger + server *restapi.Server +@@ -56,6 +63,7 @@ func NewServer(params Params) *Server { + host: params.Host, + port: params.Port, + basicAuth: params.BasicAuth, ++ authToken: params.AuthToken, + tls: params.TLS, + logger: params.Logger, + handlers: params.Handlers, +@@ -77,6 +85,11 @@ func (svr *Server) Serve(ctx context.Context) (err error) { + middlewareOptions := &pkgmiddleware.Options{ + Handler: svr.defaultRoutes(chi.NewRouter()), + } ++ if svr.authToken != nil { ++ middlewareOptions.AuthToken = &pkgmiddleware.AuthToken{ ++ Token: svr.authToken.Token, ++ } ++ } + if svr.basicAuth != nil { + middlewareOptions.BasicAuth = &pkgmiddleware.BasicAuth{ + Username: svr.basicAuth.Username, +@@ -90,7 +103,6 @@ func (svr *Server) Serve(ctx context.Context) (err error) { + svr.logger.Error("failed to load API spec", tag.Error(err)) + return err + } +- + api := operations.NewDaguAPI(swaggerSpec) + for _, h := range svr.handlers { + h.Configure(api) diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 8be9f6925..fce54497f 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -3,6 +3,7 @@ package frontend import ( "context" "embed" + "github.com/dagu-dev/dagu/internal/config" "github.com/dagu-dev/dagu/internal/logger" "github.com/dagu-dev/dagu/service/frontend/handlers" @@ -53,6 +54,13 @@ func New(params Params) *server.Server { AssetsFS: assetsFS, } + if params.Config.IsAuthToken { + + serverParams.AuthToken = &server.AuthToken{ + Token: params.Config.AuthToken, + } + } + if params.Config.IsBasicAuth { serverParams.BasicAuth = &server.BasicAuth{ Username: params.Config.BasicAuthUsername, diff --git a/service/frontend/middleware/global.go b/service/frontend/middleware/global.go index 09a9b9cb9..a1723e2b5 100644 --- a/service/frontend/middleware/global.go +++ b/service/frontend/middleware/global.go @@ -1,9 +1,10 @@ package middleware import ( - "github.com/go-chi/chi/v5/middleware" "net/http" "strings" + + "github.com/go-chi/chi/v5/middleware" ) func SetupGlobalMiddleware(handler http.Handler) http.Handler { @@ -12,12 +13,16 @@ func SetupGlobalMiddleware(handler http.Handler) http.Handler { next = middleware.Logger(next) next = middleware.Recoverer(next) + if authToken != nil { + next = TokenAuth("bearer", authToken.Token)(next) + basicAuth = nil + } + if basicAuth != nil { next = middleware.BasicAuth( "restricted", map[string]string{basicAuth.Username: basicAuth.Password}, )(next) } - next = prefixChecker(next) return next @@ -26,11 +31,13 @@ func SetupGlobalMiddleware(handler http.Handler) http.Handler { var ( defaultHandler http.Handler basicAuth *BasicAuth + authToken *AuthToken ) type Options struct { Handler http.Handler BasicAuth *BasicAuth + AuthToken *AuthToken } type BasicAuth struct { @@ -38,9 +45,14 @@ type BasicAuth struct { Password string } +type AuthToken struct { + Token string +} + func Setup(opts *Options) { defaultHandler = opts.Handler basicAuth = opts.BasicAuth + authToken = opts.AuthToken } func prefixChecker(next http.Handler) http.Handler { diff --git a/service/frontend/middleware/tokenAuth.go b/service/frontend/middleware/tokenAuth.go new file mode 100644 index 000000000..78fb43e5d --- /dev/null +++ b/service/frontend/middleware/tokenAuth.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "crypto/subtle" + "fmt" + "net/http" +) + +// TokenAuth implements a similar middleware handler like go-chi's BasicAuth middleware but for bearer tokens +func TokenAuth(realm string, token string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bearer := r.Header.Get("Bearer") + if bearer == "" { + tokenAuthFailed(w, realm) + return + } + + if subtle.ConstantTimeCompare([]byte(bearer), []byte(token)) != 1 { + tokenAuthFailed(w, realm) + return + } + + next.ServeHTTP(w, r) + + }) + } +} + +func tokenAuthFailed(w http.ResponseWriter, realm string) { + w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm)) + w.WriteHeader(http.StatusUnauthorized) +} diff --git a/service/frontend/server/server.go b/service/frontend/server/server.go index eca1b3c4b..d6eb8bc40 100644 --- a/service/frontend/server/server.go +++ b/service/frontend/server/server.go @@ -3,17 +3,18 @@ package server import ( "context" "errors" + "io/fs" + "net/http" + "os" + "os/signal" + "syscall" + "github.com/dagu-dev/dagu/internal/config" "github.com/dagu-dev/dagu/internal/logger" "github.com/dagu-dev/dagu/internal/logger/tag" "github.com/dagu-dev/dagu/service/frontend/restapi" "github.com/go-openapi/loads" flags "github.com/jessevdk/go-flags" - "io/fs" - "net/http" - "os" - "os/signal" - "syscall" pkgmiddleware "github.com/dagu-dev/dagu/service/frontend/middleware" "github.com/dagu-dev/dagu/service/frontend/restapi/operations" @@ -26,10 +27,15 @@ type BasicAuth struct { Password string } +type AuthToken struct { + Token string +} + type Params struct { Host string Port int BasicAuth *BasicAuth + AuthToken *AuthToken TLS *config.TLS Logger logger.Logger Handlers []New @@ -40,6 +46,7 @@ type Server struct { host string port int basicAuth *BasicAuth + authToken *AuthToken tls *config.TLS logger logger.Logger server *restapi.Server @@ -56,6 +63,7 @@ func NewServer(params Params) *Server { host: params.Host, port: params.Port, basicAuth: params.BasicAuth, + authToken: params.AuthToken, tls: params.TLS, logger: params.Logger, handlers: params.Handlers, @@ -77,6 +85,11 @@ func (svr *Server) Serve(ctx context.Context) (err error) { middlewareOptions := &pkgmiddleware.Options{ Handler: svr.defaultRoutes(chi.NewRouter()), } + if svr.authToken != nil { + middlewareOptions.AuthToken = &pkgmiddleware.AuthToken{ + Token: svr.authToken.Token, + } + } if svr.basicAuth != nil { middlewareOptions.BasicAuth = &pkgmiddleware.BasicAuth{ Username: svr.basicAuth.Username, @@ -90,7 +103,6 @@ func (svr *Server) Serve(ctx context.Context) (err error) { svr.logger.Error("failed to load API spec", tag.Error(err)) return err } - api := operations.NewDaguAPI(swaggerSpec) for _, h := range svr.handlers { h.Configure(api)