From 6737e0b94548b5aed14b6c6b4b079b49888bf342 Mon Sep 17 00:00:00 2001 From: presbrey Date: Mon, 27 Jan 2025 16:15:06 -0500 Subject: [PATCH] feat: Add server timeouts, health check, and improved graceful shutdown --- main.go | 103 ++++++++++++++++++++++++++++++++++++++------------- main_test.go | 28 ++++++++++++-- 2 files changed, 101 insertions(+), 30 deletions(-) diff --git a/main.go b/main.go index 46fcfda..c058d01 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "path/filepath" "strings" "sync" + "time" "os/signal" @@ -23,12 +24,18 @@ import ( // Options holds the application configuration type Options struct { - HTTPAddr string `env:"HTTP_ADDR" envDefault:""` - Verbose bool `env:"VERBOSE" envDefault:"false"` - LogRequests bool `env:"LOG_REQUESTS" envDefault:"false"` - LogResponses bool `env:"LOG_RESPONSES" envDefault:"false"` - LogErrors bool `env:"LOG_ERRORS" envDefault:"true"` - LogFile string `env:"LOG_FILE" envDefault:""` + HTTPAddr string `env:"HTTP_ADDR" envDefault:""` + ReadTimeout time.Duration `env:"READ_TIMEOUT" envDefault:"30s"` + WriteTimeout time.Duration `env:"WRITE_TIMEOUT" envDefault:"30s"` + IdleTimeout time.Duration `env:"IDLE_TIMEOUT" envDefault:"60s"` + ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"60s"` + Verbose bool `env:"VERBOSE" envDefault:"false"` + LogRequests bool `env:"LOG_REQUESTS" envDefault:"false"` + LogResponses bool `env:"LOG_RESPONSES" envDefault:"false"` + LogErrors bool `env:"LOG_ERRORS" envDefault:"true"` + LogFile string `env:"LOG_FILE" envDefault:""` + HealthCheck string `env:"HEALTH_CHECK" envDefault:"/health"` + RetryAttempts int `env:"RETRY_ATTEMPTS" envDefault:"3"` } func isWebScheme(s string) bool { @@ -50,6 +57,14 @@ func NewLightyMux(opts *Options) (*LightyMux, error) { opts = &Options{} } + // Set default values if not provided + if opts.HealthCheck == "" { + opts.HealthCheck = "/health" + } + if opts.RetryAttempts == 0 { + opts.RetryAttempts = 3 + } + // Configure logging var logWriter io.Writer = os.Stdout if opts.LogFile != "" { @@ -67,29 +82,60 @@ func NewLightyMux(opts *Options) (*LightyMux, error) { muxLock: sync.RWMutex{}, } - // Set up HTTP server + // Add health check endpoint + l.mux.HandleFunc(opts.HealthCheck, l.handleHealthCheck) + + // Set up HTTP server with timeouts l.server = &http.Server{ - Addr: opts.HTTPAddr, - Handler: l, + Addr: opts.HTTPAddr, + Handler: l, + ReadTimeout: opts.ReadTimeout, // 0 means no timeout + WriteTimeout: opts.WriteTimeout, // 0 means no timeout + IdleTimeout: opts.IdleTimeout, // 0 means no timeout } return l, nil } func (lm *LightyMux) newReverseProxy(nextHop *url.URL) *httputil.ReverseProxy { + transport := &http.Transport{ + ResponseHeaderTimeout: lm.options.ProxyTimeout, + MaxIdleConnsPerHost: 100, + } + rp := &httputil.ReverseProxy{ + Transport: transport, Director: func(req *http.Request) { req.URL.Scheme = nextHop.Scheme req.URL.Host = nextHop.Host req.URL.Path = singleJoiningSlash(nextHop.Path, req.URL.Path) + if lm.options.Verbose { + lm.logger.Printf("Proxying request to: %s", req.URL.String()) + } + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + if lm.options.LogErrors { + lm.logger.Printf("Proxy error: %v", err) + } + w.WriteHeader(http.StatusBadGateway) + fmt.Fprintf(w, "Proxy Error: %v", err) }, } + if lm.options.LogResponses { rp.ModifyResponse = lm.modifyResponse } + return rp } +// handleHealthCheck handles the health check endpoint +func (lm *LightyMux) handleHealthCheck(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"status":"healthy","timestamp":"%s"}`, time.Now().Format(time.RFC3339)) +} + func (lm *LightyMux) loadConfig(filename string) error { file, err := os.Open(filename) if err != nil { @@ -221,36 +267,41 @@ func singleJoiningSlash(a, b string) string { } func (lm *LightyMux) Run(configFile string) error { - // Initialize proxy handler with config if err := lm.loadConfig(configFile); err != nil { - return fmt.Errorf("error loading config: %v", err) + return fmt.Errorf("failed to load config: %v", err) } - // Start config file watcher if err := lm.watchConfig(configFile); err != nil { - return err + return fmt.Errorf("failed to watch config: %v", err) } - // Setup graceful shutdown - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + // Set up graceful shutdown + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + defer stop() + // Start server in a goroutine go func() { - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt) - <-sigChan - lm.logger.Println("Shutting down server...") - if err := lm.server.Shutdown(ctx); err != nil { - lm.logger.Printf("Error during server shutdown: %v", err) + if err := lm.server.ListenAndServe(); err != http.ErrServerClosed { + lm.logger.Printf("HTTP server error: %v", err) } - cancel() }() - lm.logger.Printf("Starting reverse proxy on %s", lm.options.HTTPAddr) - if err := lm.server.ListenAndServe(); err != http.ErrServerClosed { - return fmt.Errorf("error starting server: %v", err) + lm.logger.Printf("Server started on %s", lm.server.Addr) + + // Wait for interrupt signal + <-ctx.Done() + lm.logger.Println("Shutting down server...") + + // Create shutdown context with timeout + shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Attempt graceful shutdown + if err := lm.server.Shutdown(shutdownCtx); err != nil { + return fmt.Errorf("server shutdown failed: %v", err) } + lm.logger.Println("Server gracefully stopped") return nil } diff --git a/main_test.go b/main_test.go index 7c04b13..3f12076 100644 --- a/main_test.go +++ b/main_test.go @@ -238,15 +238,23 @@ func TestNewLightyMux(t *testing.T) { { name: "Custom address", opts: &Options{ - HTTPAddr: ":9090", + HTTPAddr: ":8080", }, }, { name: "With logging options", opts: &Options{ + LogFile: "test.log", LogRequests: true, LogResponses: true, - LogErrors: true, + }, + }, + { + name: "Zero timeouts disable timeouts", + opts: &Options{ + ReadTimeout: 0, + WriteTimeout: 0, + IdleTimeout: 0, }, }, } @@ -258,8 +266,20 @@ func TestNewLightyMux(t *testing.T) { t.Errorf("NewLightyMux() error = %v, wantErr %v", err, tt.wantErr) return } - if lm == nil { - t.Error("NewLightyMux() returned nil LightyMux") + if err != nil { + return + } + + if tt.name == "Zero timeouts disable timeouts" { + if lm.server.ReadTimeout != 0 { + t.Errorf("ReadTimeout = %v, want 0", lm.server.ReadTimeout) + } + if lm.server.WriteTimeout != 0 { + t.Errorf("WriteTimeout = %v, want 0", lm.server.WriteTimeout) + } + if lm.server.IdleTimeout != 0 { + t.Errorf("IdleTimeout = %v, want 0", lm.server.IdleTimeout) + } } }) }