Skip to content

Commit

Permalink
feat: Add server timeouts, health check, and improved graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
presbrey committed Jan 27, 2025
1 parent b20fbff commit 6737e0b
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 30 deletions.
103 changes: 77 additions & 26 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"path/filepath"
"strings"
"sync"
"time"

"os/signal"

Expand All @@ -23,12 +24,18 @@ import (

// Options holds the application configuration
type Options struct {
HTTPAddr string `env:"HTTP_ADDR" envDefault:""`
Verbose bool `env:"VERBOSE" envDefault:"false"`
LogRequests bool `env:"LOG_REQUESTS" envDefault:"false"`
LogResponses bool `env:"LOG_RESPONSES" envDefault:"false"`
LogErrors bool `env:"LOG_ERRORS" envDefault:"true"`
LogFile string `env:"LOG_FILE" envDefault:""`
HTTPAddr string `env:"HTTP_ADDR" envDefault:""`
ReadTimeout time.Duration `env:"READ_TIMEOUT" envDefault:"30s"`
WriteTimeout time.Duration `env:"WRITE_TIMEOUT" envDefault:"30s"`
IdleTimeout time.Duration `env:"IDLE_TIMEOUT" envDefault:"60s"`
ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"60s"`
Verbose bool `env:"VERBOSE" envDefault:"false"`
LogRequests bool `env:"LOG_REQUESTS" envDefault:"false"`
LogResponses bool `env:"LOG_RESPONSES" envDefault:"false"`
LogErrors bool `env:"LOG_ERRORS" envDefault:"true"`
LogFile string `env:"LOG_FILE" envDefault:""`
HealthCheck string `env:"HEALTH_CHECK" envDefault:"/health"`
RetryAttempts int `env:"RETRY_ATTEMPTS" envDefault:"3"`
}

func isWebScheme(s string) bool {
Expand All @@ -50,6 +57,14 @@ func NewLightyMux(opts *Options) (*LightyMux, error) {
opts = &Options{}
}

// Set default values if not provided
if opts.HealthCheck == "" {
opts.HealthCheck = "/health"
}
if opts.RetryAttempts == 0 {
opts.RetryAttempts = 3
}

// Configure logging
var logWriter io.Writer = os.Stdout
if opts.LogFile != "" {
Expand All @@ -67,29 +82,60 @@ func NewLightyMux(opts *Options) (*LightyMux, error) {
muxLock: sync.RWMutex{},
}

// Set up HTTP server
// Add health check endpoint
l.mux.HandleFunc(opts.HealthCheck, l.handleHealthCheck)

// Set up HTTP server with timeouts
l.server = &http.Server{
Addr: opts.HTTPAddr,
Handler: l,
Addr: opts.HTTPAddr,
Handler: l,
ReadTimeout: opts.ReadTimeout, // 0 means no timeout
WriteTimeout: opts.WriteTimeout, // 0 means no timeout
IdleTimeout: opts.IdleTimeout, // 0 means no timeout
}

return l, nil
}

func (lm *LightyMux) newReverseProxy(nextHop *url.URL) *httputil.ReverseProxy {
transport := &http.Transport{
ResponseHeaderTimeout: lm.options.ProxyTimeout,
MaxIdleConnsPerHost: 100,
}

rp := &httputil.ReverseProxy{
Transport: transport,
Director: func(req *http.Request) {
req.URL.Scheme = nextHop.Scheme
req.URL.Host = nextHop.Host
req.URL.Path = singleJoiningSlash(nextHop.Path, req.URL.Path)
if lm.options.Verbose {
lm.logger.Printf("Proxying request to: %s", req.URL.String())
}
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
if lm.options.LogErrors {
lm.logger.Printf("Proxy error: %v", err)
}
w.WriteHeader(http.StatusBadGateway)
fmt.Fprintf(w, "Proxy Error: %v", err)
},
}

if lm.options.LogResponses {
rp.ModifyResponse = lm.modifyResponse
}

return rp
}

// handleHealthCheck handles the health check endpoint
func (lm *LightyMux) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"status":"healthy","timestamp":"%s"}`, time.Now().Format(time.RFC3339))
}

func (lm *LightyMux) loadConfig(filename string) error {
file, err := os.Open(filename)
if err != nil {
Expand Down Expand Up @@ -221,36 +267,41 @@ func singleJoiningSlash(a, b string) string {
}

func (lm *LightyMux) Run(configFile string) error {
// Initialize proxy handler with config
if err := lm.loadConfig(configFile); err != nil {
return fmt.Errorf("error loading config: %v", err)
return fmt.Errorf("failed to load config: %v", err)
}

// Start config file watcher
if err := lm.watchConfig(configFile); err != nil {
return err
return fmt.Errorf("failed to watch config: %v", err)
}

// Setup graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set up graceful shutdown
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
defer stop()

// Start server in a goroutine
go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
<-sigChan
lm.logger.Println("Shutting down server...")
if err := lm.server.Shutdown(ctx); err != nil {
lm.logger.Printf("Error during server shutdown: %v", err)
if err := lm.server.ListenAndServe(); err != http.ErrServerClosed {
lm.logger.Printf("HTTP server error: %v", err)
}
cancel()
}()

lm.logger.Printf("Starting reverse proxy on %s", lm.options.HTTPAddr)
if err := lm.server.ListenAndServe(); err != http.ErrServerClosed {
return fmt.Errorf("error starting server: %v", err)
lm.logger.Printf("Server started on %s", lm.server.Addr)

// Wait for interrupt signal
<-ctx.Done()
lm.logger.Println("Shutting down server...")

// Create shutdown context with timeout
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

// Attempt graceful shutdown
if err := lm.server.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("server shutdown failed: %v", err)
}

lm.logger.Println("Server gracefully stopped")
return nil
}

Expand Down
28 changes: 24 additions & 4 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,23 @@ func TestNewLightyMux(t *testing.T) {
{
name: "Custom address",
opts: &Options{
HTTPAddr: ":9090",
HTTPAddr: ":8080",
},
},
{
name: "With logging options",
opts: &Options{
LogFile: "test.log",
LogRequests: true,
LogResponses: true,
LogErrors: true,
},
},
{
name: "Zero timeouts disable timeouts",
opts: &Options{
ReadTimeout: 0,
WriteTimeout: 0,
IdleTimeout: 0,
},
},
}
Expand All @@ -258,8 +266,20 @@ func TestNewLightyMux(t *testing.T) {
t.Errorf("NewLightyMux() error = %v, wantErr %v", err, tt.wantErr)
return
}
if lm == nil {
t.Error("NewLightyMux() returned nil LightyMux")
if err != nil {
return
}

if tt.name == "Zero timeouts disable timeouts" {
if lm.server.ReadTimeout != 0 {
t.Errorf("ReadTimeout = %v, want 0", lm.server.ReadTimeout)
}
if lm.server.WriteTimeout != 0 {
t.Errorf("WriteTimeout = %v, want 0", lm.server.WriteTimeout)
}
if lm.server.IdleTimeout != 0 {
t.Errorf("IdleTimeout = %v, want 0", lm.server.IdleTimeout)
}
}
})
}
Expand Down

0 comments on commit 6737e0b

Please sign in to comment.