diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fd7a6e247b..367a9940aad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,7 +51,7 @@ and this project adheres to - Improved HTTP requests handling and timeouts ([#2343]). - Our snap package now uses the `core20` image as its base ([#2306]). - New build system and various internal improvements ([#2271], [#2276], [#2297], - [#2509]). + [#2509], [#2552]). [#2231]: https://github.com/AdguardTeam/AdGuardHome/issues/2231 [#2271]: https://github.com/AdguardTeam/AdGuardHome/issues/2271 @@ -63,6 +63,7 @@ and this project adheres to [#2391]: https://github.com/AdguardTeam/AdGuardHome/issues/2391 [#2394]: https://github.com/AdguardTeam/AdGuardHome/issues/2394 [#2509]: https://github.com/AdguardTeam/AdGuardHome/issues/2509 +[#2552]: https://github.com/AdguardTeam/AdGuardHome/issues/2552 [#2589]: https://github.com/AdguardTeam/AdGuardHome/issues/2589 ### Deprecated diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index da223ebd4ef..0ab24f60942 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -13,6 +13,7 @@ import ( "runtime" "strconv" "strings" + "time" "github.com/AdguardTeam/AdGuardHome/internal/util" @@ -268,6 +269,9 @@ func copyInstallSettings(dst, src *configuration) { dst.DNS.Port = src.DNS.Port } +// shutdownTimeout is the timeout for shutting HTTP server down operation. +const shutdownTimeout = 5 * time.Second + // Apply new configuration, start DNS server, restart Web server func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { newSettings := applyConfigReq{} @@ -320,6 +324,10 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { config.DNS.BindHost = newSettings.DNS.IP config.DNS.Port = newSettings.DNS.Port + // TODO(e.burkov): StartMods() should be put in a separate goroutine at + // the moment we'll allow setting up TLS in the initial configuration or + // the configuration itself will use HTTPS protocol, because the + // underlying functions potentially restart the HTTPS server. err = StartMods() if err != nil { Context.firstRun = true @@ -351,16 +359,22 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { f.Flush() } - // this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block - // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely + // The Shutdown() method of (*http.Server) needs to be called in a + // separate goroutine, because it waits until all requests are handled + // and will be blocked by it's own caller. if restartHTTP { - go func() { - _ = web.httpServer.Shutdown(context.TODO()) - }() + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + + shut := func(srv *http.Server) { + defer cancel() + err := srv.Shutdown(ctx) + if err != nil { + log.Debug("error while shutting down HTTP server: %s", err) + } + } + go shut(web.httpServer) if web.httpServerBeta != nil { - go func() { - _ = web.httpServerBeta.Shutdown(context.TODO()) - }() + go shut(web.httpServerBeta) } } } diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index 327f2f908b1..8b3f1bf519f 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -1,6 +1,7 @@ package home import ( + "context" "encoding/json" "errors" "net/http" @@ -90,7 +91,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { } } -// Perform an update procedure to the latest available version +// handleUpdate performs an update to the latest available version procedure. func handleUpdate(w http.ResponseWriter, _ *http.Request) { if Context.updater.NewVersion() == "" { httpError(w, http.StatusBadRequest, "/update request isn't allowed now") @@ -108,7 +109,13 @@ func handleUpdate(w http.ResponseWriter, _ *http.Request) { f.Flush() } - go finishUpdate() + // The background context is used because the underlying functions wrap + // it with timeout and shut down the server, which handles current + // request. It also should be done in a separate goroutine due to the + // same reason. + go func() { + finishUpdate(context.Background()) + }() } // versionResponse is the response for /control/version.json endpoint. @@ -140,10 +147,10 @@ func (vr *versionResponse) confirmAutoUpdate() { } } -// Complete an update procedure -func finishUpdate() { +// finishUpdate completes an update procedure. +func finishUpdate(ctx context.Context) { log.Info("Stopping all tasks") - cleanup() + cleanup(ctx) cleanupAlways() exeName := "AdGuardHome" diff --git a/internal/home/controlupdate_test.go b/internal/home/controlupdate_test.go deleted file mode 100644 index 45112f50c25..00000000000 --- a/internal/home/controlupdate_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// +build ignore - -package home - -import ( - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDoUpdate(t *testing.T) { - config.DNS.Port = 0 - Context.workDir = "..." // set absolute path - newver := "v0.96" - - data := `{ - "version": "v0.96", - "announcement": "AdGuard Home v0.96 is now available!", - "announcement_url": "", - "download_windows_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_amd64.zip", - "download_windows_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_386.zip", - "download_darwin_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_amd64.zip", - "download_darwin_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_386.zip", - "download_linux_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz", - "download_linux_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_386.tar.gz", - "download_linux_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz", - "download_linux_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz", - "download_linux_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz", - "download_linux_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz", - "download_linux_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz", - "download_linux_mips": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz", - "download_linux_mipsle": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz", - "download_linux_mips64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz", - "download_linux_mips64le": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz", - "download_freebsd_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz", - "download_freebsd_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz", - "download_freebsd_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz", - "download_freebsd_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz", - "download_freebsd_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz", - "download_freebsd_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz", - "download_freebsd_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz" - }` - uu, err := getUpdateInfo([]byte(data)) - if err != nil { - t.Fatalf("getUpdateInfo: %s", err) - } - - u := updateInfo{ - pkgURL: "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz", - pkgName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz", - newVer: newver, - updateDir: Context.workDir + "/agh-update-" + newver, - backupDir: Context.workDir + "/agh-backup", - configName: Context.workDir + "/AdGuardHome.yaml", - updateConfigName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/internal/AdGuardHome.yaml", - curBinName: Context.workDir + "/AdGuardHome", - bkpBinName: Context.workDir + "/agh-backup/AdGuardHome", - newBinName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/internal/AdGuardHome", - } - - assert.Equal(t, uu.pkgURL, u.pkgURL) - assert.Equal(t, uu.pkgName, u.pkgName) - assert.Equal(t, uu.newVer, u.newVer) - assert.Equal(t, uu.updateDir, u.updateDir) - assert.Equal(t, uu.backupDir, u.backupDir) - assert.Equal(t, uu.configName, u.configName) - assert.Equal(t, uu.updateConfigName, u.updateConfigName) - assert.Equal(t, uu.curBinName, u.curBinName) - assert.Equal(t, uu.bkpBinName, u.bkpBinName) - assert.Equal(t, uu.newBinName, u.newBinName) - - e := doUpdate(&u) - if e != nil { - t.Fatalf("FAILED: %s", e) - } - os.RemoveAll(u.backupDir) -} - -func TestTargzFileUnpack(t *testing.T) { - fn := "../dist/AdGuardHome_linux_amd64.tar.gz" - outdir := "../test-unpack" - defer os.RemoveAll(outdir) - _ = os.Mkdir(outdir, 0o755) - files, e := targzFileUnpack(fn, outdir) - if e != nil { - t.Fatalf("FAILED: %s", e) - } - t.Logf("%v", files) -} - -func TestZipFileUnpack(t *testing.T) { - fn := "../dist/AdGuardHome_windows_amd64.zip" - outdir := "../test-unpack" - _ = os.Mkdir(outdir, 0o755) - files, e := zipFileUnpack(fn, outdir) - if e != nil { - t.Fatalf("FAILED: %s", e) - } - t.Logf("%v", files) - os.RemoveAll(outdir) -} diff --git a/internal/home/home.go b/internal/home/home.go index 939d501fdda..1b6312c8c60 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -109,7 +109,7 @@ func Main() { Context.tls.Reload() default: - cleanup() + cleanup(context.Background()) cleanupAlways() os.Exit(0) } @@ -335,7 +335,7 @@ func run(args options) { select {} } -// StartMods - initialize and start DNS after installation +// StartMods initializes and starts the DNS server after installation. func StartMods() error { err := initDNSServer() if err != nil { @@ -506,11 +506,12 @@ func configureLogger(args options) { } } -func cleanup() { +// cleanup stops and resets all the modules. +func cleanup(ctx context.Context) { log.Info("Stopping AdGuard Home") if Context.web != nil { - Context.web.Close() + Context.web.Close(ctx) Context.web = nil } if Context.auth != nil { diff --git a/internal/home/home_test.go b/internal/home/home_test.go index 344657e9efa..5f3f3d81462 100644 --- a/internal/home/home_test.go +++ b/internal/home/home_test.go @@ -186,6 +186,6 @@ func TestHome(t *testing.T) { time.Sleep(1 * time.Second) } - cleanup() + cleanup(context.Background()) cleanupAlways() } diff --git a/internal/home/tls.go b/internal/home/tls.go index de6dcd186b4..1f68cc3950f 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -1,6 +1,7 @@ package home import ( + "context" "crypto" "crypto/ecdsa" "crypto/rsa" @@ -92,7 +93,7 @@ func (t *TLSMod) setCertFileTime() { t.certLastMod = fi.ModTime().UTC() } -// Start - start the module +// Start updates the configuration of TLSMod and starts it. func (t *TLSMod) Start() { if !tlsWebHandlersRegistered { tlsWebHandlersRegistered = true @@ -102,10 +103,14 @@ func (t *TLSMod) Start() { t.confLock.Lock() tlsConf := t.conf t.confLock.Unlock() - Context.web.TLSConfigChanged(tlsConf) + + // The background context is used because the TLSConfigChanged wraps + // context with timeout on its own and shuts down the server, which + // handles current request. + Context.web.TLSConfigChanged(context.Background(), tlsConf) } -// Reload - reload certificate file +// Reload updates the configuration of TLSMod and restarts it. func (t *TLSMod) Reload() { t.confLock.Lock() tlsConf := t.conf @@ -139,7 +144,10 @@ func (t *TLSMod) Reload() { t.confLock.Lock() tlsConf = t.conf t.confLock.Unlock() - Context.web.TLSConfigChanged(tlsConf) + // The background context is used because the TLSConfigChanged wraps + // context with timeout on its own and shuts down the server, which + // handles current request. + Context.web.TLSConfigChanged(context.Background(), tlsConf) } // Set certificate and private key data @@ -296,11 +304,13 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { f.Flush() } - // this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block - // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely + // The background context is used because the TLSConfigChanged wraps + // context with timeout on its own and shuts down the server, which + // handles current request. It is also should be done in a separate + // goroutine due to the same reason. if restartHTTPS { go func() { - Context.web.TLSConfigChanged(data) + Context.web.TLSConfigChanged(context.Background(), data) }() } } diff --git a/internal/home/web.go b/internal/home/web.go index 72a2eb62198..0048f42755a 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -122,8 +122,9 @@ func WebCheckPortAvailable(port int) bool { return true } -// TLSConfigChanged - called when TLS configuration has changed -func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) { +// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server +// if necessary. +func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) { log.Debug("Web: applying new TLS configuration") web.conf.PortHTTPS = tlsConf.PortHTTPS web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0) @@ -143,7 +144,12 @@ func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) { web.httpsServer.cond.L.Lock() if web.httpsServer.server != nil { - _ = web.httpsServer.server.Shutdown(context.TODO()) + ctx, cancel := context.WithTimeout(ctx, shutdownTimeout) + err = web.httpsServer.server.Shutdown(ctx) + cancel() + if err != nil { + log.Debug("error while shutting down HTTP server: %s", err) + } } web.httpsServer.enabled = enabled web.httpsServer.cert = cert @@ -198,22 +204,28 @@ func (web *Web) Start() { } } -// Close - stop HTTP server, possibly waiting for all active connections to be closed -func (web *Web) Close() { +// Close gracefully shuts down the HTTP servers. +func (web *Web) Close(ctx context.Context) { log.Info("Stopping HTTP server...") web.httpsServer.cond.L.Lock() web.httpsServer.shutdown = true web.httpsServer.cond.L.Unlock() - if web.httpsServer.server != nil { - _ = web.httpsServer.server.Shutdown(context.TODO()) - } - if web.httpServer != nil { - _ = web.httpServer.Shutdown(context.TODO()) - } - if web.httpServerBeta != nil { - _ = web.httpServerBeta.Shutdown(context.TODO()) + + shut := func(srv *http.Server) { + if srv == nil { + return + } + ctx, cancel := context.WithTimeout(ctx, shutdownTimeout) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + log.Debug("error while shutting down HTTP server: %s", err) + } } + shut(web.httpsServer.server) + shut(web.httpServer) + shut(web.httpServerBeta) + log.Info("Stopped HTTP server") } diff --git a/internal/home/whois.go b/internal/home/whois.go index dcdeea9a8d7..26c674dceb1 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -36,19 +36,20 @@ type Whois struct { timeoutMsec uint } -// Create module context +// initWhois creates the Whois module context. func initWhois(clients *clientsContainer) *Whois { - w := Whois{} - w.timeoutMsec = 5000 - w.clients = clients - - cconf := cache.Config{} - cconf.EnableLRU = true - cconf.MaxCount = 10000 - w.ipAddrs = cache.New(cconf) + w := Whois{ + timeoutMsec: 5000, + clients: clients, + ipAddrs: cache.New(cache.Config{ + EnableLRU: true, + MaxCount: 10000, + }), + ipChan: make(chan net.IP, 255), + } - w.ipChan = make(chan net.IP, 255) go w.workerLoop() + return &w } @@ -121,12 +122,12 @@ func whoisParse(data string) map[string]string { const MaxConnReadSize = 64 * 1024 // Send request to a server and receive the response -func (w *Whois) query(target, serverAddr string) (string, error) { +func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, error) { addr, _, _ := net.SplitHostPort(serverAddr) if addr == "whois.arin.net" { target = "n + " + target } - conn, err := customDialContext(context.TODO(), "tcp", serverAddr) + conn, err := customDialContext(ctx, "tcp", serverAddr) if err != nil { return "", err } @@ -154,11 +155,11 @@ func (w *Whois) query(target, serverAddr string) (string, error) { } // Query WHOIS servers (handle redirects) -func (w *Whois) queryAll(target string) (string, error) { +func (w *Whois) queryAll(ctx context.Context, target string) (string, error) { server := net.JoinHostPort(defaultServer, defaultPort) const maxRedirects = 5 for i := 0; i != maxRedirects; i++ { - resp, err := w.query(target, server) + resp, err := w.query(ctx, target, server) if err != nil { return "", err } @@ -184,9 +185,9 @@ func (w *Whois) queryAll(target string) (string, error) { } // Request WHOIS information -func (w *Whois) process(ip net.IP) [][]string { +func (w *Whois) process(ctx context.Context, ip net.IP) [][]string { data := [][]string{} - resp, err := w.queryAll(ip.String()) + resp, err := w.queryAll(ctx, ip.String()) if err != nil { log.Debug("Whois: error: %s IP:%s", err, ip) return data @@ -233,12 +234,13 @@ func (w *Whois) Begin(ip net.IP) { } } -// Get IP address from channel; get WHOIS info; associate info with a client +// workerLoop processes the IP addresses it got from the channel and associates +// the retrieving WHOIS info with a client. func (w *Whois) workerLoop() { for { ip := <-w.ipChan - info := w.process(ip) + info := w.process(context.Background(), ip) if len(info) == 0 { continue } diff --git a/internal/home/whois_test.go b/internal/home/whois_test.go index f8109e9268d..a160cdda489 100644 --- a/internal/home/whois_test.go +++ b/internal/home/whois_test.go @@ -1,6 +1,7 @@ package home import ( + "context" "testing" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" @@ -19,7 +20,7 @@ func TestWhois(t *testing.T) { assert.Nil(t, prepareTestDNSServer()) w := Whois{timeoutMsec: 5000} - resp, err := w.queryAll("8.8.8.8") + resp, err := w.queryAll(context.Background(), "8.8.8.8") assert.Nil(t, err) m := whoisParse(resp) assert.Equal(t, "Google LLC", m["orgname"]) diff --git a/internal/sysutil/syslog_others.go b/internal/sysutil/syslog_others.go index 0e0e1c3f0c0..91fe27cf569 100644 --- a/internal/sysutil/syslog_others.go +++ b/internal/sysutil/syslog_others.go @@ -1,4 +1,4 @@ -// +build !windows,!nacl,!plan9 +// +build !windows,!plan9 package sysutil diff --git a/internal/sysutil/syslog_windows.go b/internal/sysutil/syslog_windows.go index 2160ea43930..1d5cd03f73c 100644 --- a/internal/sysutil/syslog_windows.go +++ b/internal/sysutil/syslog_windows.go @@ -1,4 +1,4 @@ -// +build windows nacl plan9 +// +build windows plan9 package sysutil