diff --git a/README.md b/README.md index 9f9cf2c2..dabcbcbc 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,31 @@ http: # Make sure to use https:// if you are using TLS. public_url: "http://localhost:23232" + # The cross-origin request security options + cors: + # The allowed cross-origin headers + allowed_headers: + - Accept + - Accept-Language + - Content-Language + - Origin + # - Content-Type + # - X-Requested-With + # - User-Agent + # - Authorization + + # The allowed cross-origin URLs + # allowed_origins: + # - * + + # The allowed cross-origin methods + allowed_methods: + - GET + - HEAD + - POST + # - PUT + # - OPTIONS + # The database configuration. db: # The database driver to use. diff --git a/pkg/config/config.go b/pkg/config/config.go index 103598a6..f71fd202 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -55,6 +55,15 @@ type GitConfig struct { MaxConnections int `env:"MAX_CONNECTIONS" yaml:"max_connections"` } +// CORSConfig is the CORS configuration for the server. +type CORSConfig struct { + AllowedHeaders []string `env:"ALLOWED_HEADERS" yaml:"allowed_headers"` + + AllowedOrigins []string `env:"ALLOWED_ORIGINS" yaml:"allowed_origins"` + + AllowedMethods []string `env:"ALLOWED_METHODS" yaml:"allowed_methods"` +} + // HTTPConfig is the HTTP configuration for the server. type HTTPConfig struct { // ListenAddr is the address on which the HTTP server will listen. @@ -68,6 +77,9 @@ type HTTPConfig struct { // PublicURL is the public URL of the HTTP server. PublicURL string `env:"PUBLIC_URL" yaml:"public_url"` + + // HTTP is the configuration for the HTTP server. + CORS CORSConfig `envPrefix:"CORS_" yaml:"cors"` } // StatsConfig is the configuration for the stats server. @@ -180,6 +192,9 @@ func (c *Config) Environ() []string { fmt.Sprintf("SOFT_SERVE_HTTP_TLS_KEY_PATH=%s", c.HTTP.TLSKeyPath), fmt.Sprintf("SOFT_SERVE_HTTP_TLS_CERT_PATH=%s", c.HTTP.TLSCertPath), fmt.Sprintf("SOFT_SERVE_HTTP_PUBLIC_URL=%s", c.HTTP.PublicURL), + fmt.Sprintf("SOFT_SERVE_HTTP_CORS_ALLOWED_HEADERS=%s", strings.Join(c.HTTP.CORS.AllowedHeaders, ",")), + fmt.Sprintf("SOFT_SERVE_HTTP_CORS_ALLOWED_ORIGINS=%s", strings.Join(c.HTTP.CORS.AllowedOrigins, ",")), + fmt.Sprintf("SOFT_SERVE_HTTP_CORS_ALLOWED_METHODS=%s", strings.Join(c.HTTP.CORS.AllowedMethods, ",")), fmt.Sprintf("SOFT_SERVE_STATS_LISTEN_ADDR=%s", c.Stats.ListenAddr), fmt.Sprintf("SOFT_SERVE_LOG_FORMAT=%s", c.Log.Format), fmt.Sprintf("SOFT_SERVE_LOG_TIME_FORMAT=%s", c.Log.TimeFormat), diff --git a/pkg/web/server.go b/pkg/web/server.go index 74a04f5b..ab336e89 100644 --- a/pkg/web/server.go +++ b/pkg/web/server.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/pkg/config" "github.com/gorilla/handlers" "github.com/gorilla/mux" ) @@ -26,5 +27,12 @@ func NewRouter(ctx context.Context) http.Handler { h = handlers.CompressHandler(h) h = handlers.RecoveryHandler()(h) + cfg := config.FromContext(ctx) + + h = handlers.CORS(handlers.AllowedHeaders(cfg.HTTP.CORS.AllowedHeaders), + handlers.AllowedOrigins(cfg.HTTP.CORS.AllowedOrigins), + handlers.AllowedMethods(cfg.HTTP.CORS.AllowedMethods), + )(h) + return h }