diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 5d1cce4a71c..8999845b13d 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -530,14 +530,14 @@ func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP) // prepareInternalProxy initializes the DNS proxy that is used for internal DNS // queries, such as public clients PTR resolving and updater hostname resolving. func (s *Server) prepareInternalProxy() (err error) { + srvConf := s.conf conf := &proxy.Config{ CacheEnabled: true, CacheSizeBytes: 4096, - UpstreamConfig: s.conf.UpstreamConfig, + UpstreamConfig: srvConf.UpstreamConfig, MaxGoroutines: int(s.conf.MaxGoroutines), } - srvConf := s.conf setProxyUpstreamMode( conf, srvConf.AllServers, diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index ef4f06592e8..5718bfaa003 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -123,7 +123,7 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) { return } - err = Context.updater.Update() + err = Context.updater.Update(false) if err != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) diff --git a/internal/home/dns.go b/internal/home/dns.go index 1980b252392..9d073d7b2e6 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -9,7 +9,9 @@ import ( "path/filepath" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" @@ -39,17 +41,13 @@ func onConfigModified() { } } -// initDNSServer creates an instance of the dnsforward.Server -// Please note that we must do it even if we don't start it -// so that we had access to the query log and the stats -func initDNSServer() (err error) { +// initDNS updates all the fields of the [Context] needed to initialize the DNS +// server and initializes it at last. It also must not be called unless +// [config] and [Context] are initialized. +func initDNS() (err error) { baseDir := Context.getDataDir() - var anonFunc aghnet.IPMutFunc - if config.DNS.AnonymizeClientIP { - anonFunc = querylog.AnonymizeIP - } - anonymizer := aghnet.NewIPMut(anonFunc) + anonymizer := config.anonymizer() statsConf := stats.Config{ Filename: filepath.Join(baseDir, "stats.db"), @@ -82,34 +80,46 @@ func initDNSServer() (err error) { return err } - var privateNets netutil.SubnetSet - switch len(config.DNS.PrivateNets) { - case 0: - // Use an optimized locally-served matcher. - privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) - case 1: - privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0]) - if err != nil { - return fmt.Errorf("preparing the set of private subnets: %w", err) - } - default: - var nets []*net.IPNet - nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...) - if err != nil { - return fmt.Errorf("preparing the set of private subnets: %w", err) - } + tlsConf := &tlsConfigSettings{} + Context.tls.WriteDiskConfig(tlsConf) + + return initDNSServer( + Context.filters, + Context.stats, + Context.queryLog, + Context.dhcpServer, + anonymizer, + httpRegister, + tlsConf, + ) +} - privateNets = netutil.SliceSubnetSet(nets) +// initDNSServer initializes the [context.dnsServer]. To only use the internal +// proxy, none of the arguments are required, but tlsConf still must not be nil, +// in other cases all the arguments also must not be nil. It also must not be +// called unless [config] and [Context] are initialized. +func initDNSServer( + filters *filtering.DNSFilter, + sts stats.Interface, + qlog querylog.QueryLog, + dhcpSrv dhcpd.Interface, + anonymizer *aghnet.IPMut, + httpReg aghhttp.RegisterFunc, + tlsConf *tlsConfigSettings, +) (err error) { + privateNets, err := parseSubnetSet(config.DNS.PrivateNets) + if err != nil { + return fmt.Errorf("preparing set of private subnets: %w", err) } p := dnsforward.DNSCreateParams{ - DNSFilter: Context.filters, - Stats: Context.stats, - QueryLog: Context.queryLog, + DNSFilter: filters, + Stats: sts, + QueryLog: qlog, PrivateNets: privateNets, Anonymizer: anonymizer, LocalDomain: config.DHCP.LocalDomainName, - DHCPServer: Context.dhcpServer, + DHCPServer: dhcpSrv, } Context.dnsServer, err = dnsforward.NewServer(p) @@ -120,15 +130,15 @@ func initDNSServer() (err error) { } Context.clients.dnsServer = Context.dnsServer - var dnsConfig dnsforward.ServerConfig - dnsConfig, err = generateServerConfig() + + dnsConf, err := generateServerConfig(tlsConf, httpReg) if err != nil { closeDNSServer() return fmt.Errorf("generateServerConfig: %w", err) } - err = Context.dnsServer.Prepare(&dnsConfig) + err = Context.dnsServer.Prepare(&dnsConf) if err != nil { closeDNSServer() @@ -146,6 +156,32 @@ func initDNSServer() (err error) { return nil } +// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns +// a subnet set that matches all locally served networks, see +// [netutil.IsLocallyServed]. +func parseSubnetSet(nets []string) (s netutil.SubnetSet, err error) { + switch len(nets) { + case 0: + // Use an optimized function-based matcher. + return netutil.SubnetSetFunc(netutil.IsLocallyServed), nil + case 1: + s, err = netutil.ParseSubnet(nets[0]) + if err != nil { + return nil, err + } + + return s, nil + default: + var nets []*net.IPNet + nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...) + if err != nil { + return nil, err + } + + return netutil.SliceSubnetSet(nets), nil + } +} + func isRunning() bool { return Context.dnsServer != nil && Context.dnsServer.IsRunning() } @@ -193,7 +229,10 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) { return udpAddrs } -func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { +func generateServerConfig( + tlsConf *tlsConfigSettings, + httpReg aghhttp.RegisterFunc, +) (newConf dnsforward.ServerConfig, err error) { dnsConf := config.DNS hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()}) newConf = dnsforward.ServerConfig{ @@ -201,12 +240,10 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port), FilteringConfig: dnsConf.FilteringConfig, ConfigModified: onConfigModified, - HTTPRegister: httpRegister, + HTTPRegister: httpReg, OnDNSRequest: onDNSRequest, } - tlsConf := tlsConfigSettings{} - Context.tls.WriteDiskConfig(&tlsConf) if tlsConf.Enabled { newConf.TLSConfig = tlsConf.TLSConfig newConf.TLSConfig.ServerName = tlsConf.ServerName @@ -224,7 +261,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { } if tlsConf.PortDNSCrypt != 0 { - newConf.DNSCryptConfig, err = newDNSCrypt(hosts, tlsConf) + newConf.DNSCryptConfig, err = newDNSCrypt(hosts, *tlsConf) if err != nil { // Don't wrap the error, because it's already // wrapped by newDNSCrypt. @@ -413,7 +450,11 @@ func startDNSServer() error { func reconfigureDNSServer() (err error) { var newConf dnsforward.ServerConfig - newConf, err = generateServerConfig() + + tlsConf := &tlsConfigSettings{} + Context.tls.WriteDiskConfig(tlsConf) + + newConf, err = generateServerConfig(tlsConf, httpRegister) if err != nil { return fmt.Errorf("generating forwarding dns server config: %w", err) } diff --git a/internal/home/home.go b/internal/home/home.go index 3085d66c488..ce464060ca0 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -455,6 +455,10 @@ func run(opts options, clientBuildFS fs.FS) { err = setupConfig(opts) fatalOnError(err) + // TODO(e.burkov): This could be made earlier, probably as the option's + // effect. + cmdlineUpdate(opts) + if !Context.firstRun { // Save the updated config err = config.write() @@ -522,7 +526,7 @@ func run(opts options, clientBuildFS fs.FS) { fatalOnError(err) if !Context.firstRun { - err = initDNSServer() + err = initDNS() fatalOnError(err) Context.tls.start() @@ -543,20 +547,24 @@ func run(opts options, clientBuildFS fs.FS) { } } - // TODO(a.garipov): This could be made much earlier and could be done on - // the first run as well, but to achieve this we need to bypass requests - // over dnsforward resolver. - cmdlineUpdate(opts) - Context.web.Start() // wait indefinitely for other go-routines to complete their job select {} } +func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) { + var anonFunc aghnet.IPMutFunc + if c.DNS.AnonymizeClientIP { + anonFunc = querylog.AnonymizeIP + } + + return aghnet.NewIPMut(anonFunc) +} + // startMods initializes and starts the DNS server after installation. -func startMods() error { - err := initDNSServer() +func startMods() (err error) { + err = initDNS() if err != nil { return err } @@ -927,8 +935,8 @@ func getHTTPProxy(_ *http.Request) (*url.URL, error) { // jsonError is a generic JSON error response. // -// TODO(a.garipov): Merge together with the implementations in .../dhcpd and -// other packages after refactoring the web handler registering. +// TODO(a.garipov): Merge together with the implementations in [dhcpd] and other +// packages after refactoring the web handler registering. type jsonError struct { // Message is the error message, an opaque string. Message string `json:"message"` @@ -940,30 +948,40 @@ func cmdlineUpdate(opts options) { return } - log.Info("starting update") - - if Context.firstRun { - log.Info("update not allowed on first run") + // Initialize the DNS server to use the internal resolver which the updater + // needs to be able to resolve the update source hostname. + // + // TODO(e.burkov): We could probably initialize the internal resolver + // separately. + err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}) + fatalOnError(err) - os.Exit(0) - } + log.Info("cmdline update: performing update") - _, err := Context.updater.VersionInfo(true) + updater := Context.updater + info, err := updater.VersionInfo(true) if err != nil { - vcu := Context.updater.VersionCheckURL() + vcu := updater.VersionCheckURL() log.Error("getting version info from %s: %s", vcu, err) - os.Exit(0) + os.Exit(1) } - if Context.updater.NewVersion() == "" { + if info.NewVersion == version.Version() { log.Info("no updates available") os.Exit(0) } - err = Context.updater.Update() + err = updater.Update(Context.firstRun) fatalOnError(err) + err = restartService() + if err != nil { + log.Debug("restarting service: %s", err) + log.Info("AdGuard Home was not installed as a service. " + + "Please restart running instances of AdGuardHome manually.") + } + os.Exit(0) } diff --git a/internal/home/options.go b/internal/home/options.go index befc78a25dc..9d435b6ed66 100644 --- a/internal/home/options.go +++ b/internal/home/options.go @@ -229,7 +229,7 @@ var cmdLineOpts = []cmdLineOpt{{ updateNoValue: func(o options) (options, error) { o.performUpdate = true; return o, nil }, effect: nil, serialize: func(o options) (val string, ok bool) { return "", o.performUpdate }, - description: "Update application and exit.", + description: "Update the current binary and restart the service in case it's installed.", longName: "update", shortName: "", }, { diff --git a/internal/home/service.go b/internal/home/service.go index 3aece1f2715..c0fe845f4f4 100644 --- a/internal/home/service.go +++ b/internal/home/service.go @@ -159,6 +159,38 @@ func sendSigReload() { log.Debug("service: sent signal to pid %d", pid) } +// restartService restarts the service. It returns error if the service is not +// running. +func restartService() (err error) { + // Call chooseSystem explicitly to introduce OpenBSD support for service + // package. It's a noop for other GOOS values. + chooseSystem() + + pwd, err := os.Getwd() + if err != nil { + return fmt.Errorf("getting current directory: %w", err) + } + + svcConfig := &service.Config{ + Name: serviceName, + DisplayName: serviceDisplayName, + Description: serviceDescription, + WorkingDirectory: pwd, + } + configureService(svcConfig) + + var s service.Service + if s, err = service.New(&program{}, svcConfig); err != nil { + return fmt.Errorf("initializing service: %w", err) + } + + if err = svcAction(s, "restart"); err != nil { + return fmt.Errorf("restarting service: %w", err) + } + + return nil +} + // handleServiceControlAction one of the possible control actions: // // - install: Installs a service/daemon. diff --git a/internal/home/service_linux.go b/internal/home/service_linux.go index 39d572a0812..e5dd2953b32 100644 --- a/internal/home/service_linux.go +++ b/internal/home/service_linux.go @@ -7,6 +7,8 @@ import ( "github.com/kardianos/service" ) +// chooseSystem checks the current system detected and substitutes it with local +// implementation if needed. func chooseSystem() { sys := service.ChosenSystem() // By default, package service uses the SysV system if it cannot detect diff --git a/internal/home/service_openbsd.go b/internal/home/service_openbsd.go index 071775b97ac..4f94f0b4dcf 100644 --- a/internal/home/service_openbsd.go +++ b/internal/home/service_openbsd.go @@ -30,6 +30,8 @@ import ( // sysVersion is the version of local service.System interface implementation. const sysVersion = "openbsd-runcom" +// chooseSystem checks the current system detected and substitutes it with local +// implementation if needed. func chooseSystem() { service.ChooseSystem(openbsdSystem{}) } diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 0ac8d9be837..b1df74c1b34 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -180,7 +180,7 @@ func withRecovered(orig *error) { // type check var _ Interface = (*StatsCtx)(nil) -// Start implements the Interface interface for *StatsCtx. +// Start implements the [Interface] interface for *StatsCtx. func (s *StatsCtx) Start() { s.initWeb() diff --git a/internal/updater/check.go b/internal/updater/check.go index 05a4b59c9ab..5de7ecfc4ba 100644 --- a/internal/updater/check.go +++ b/internal/updater/check.go @@ -61,7 +61,7 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) { return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err) } - u.prevCheckTime = time.Now() + u.prevCheckTime = now u.prevCheckResult, u.prevCheckError = u.parseVersionResponse(body) return u.prevCheckResult, u.prevCheckError diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 3d89f7dd385..f042ab3c3d1 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -104,49 +104,58 @@ func NewUpdater(conf *Config) *Updater { } } -// Update performs the auto-update. -func (u *Updater) Update() (err error) { +// Update performs the auto-update. It returns an error if the update failed. +// If firstRun is true, it assumes the configuration file doesn't exist. +func (u *Updater) Update(firstRun bool) (err error) { u.mu.Lock() defer u.mu.Unlock() log.Info("updater: updating") - defer func() { log.Info("updater: finished; errors: %v", err) }() + defer func() { + if err != nil { + log.Error("updater: failed: %v", err) + } else { + log.Info("updater: finished") + } + }() execPath, err := os.Executable() if err != nil { - return err + return fmt.Errorf("getting executable path: %w", err) } err = u.prepare(execPath) if err != nil { - return err + return fmt.Errorf("preparing: %w", err) } defer u.clean() - err = u.downloadPackageFile(u.packageURL, u.packageName) + err = u.downloadPackageFile() if err != nil { - return err + return fmt.Errorf("downloading package file: %w", err) } err = u.unpack() if err != nil { - return err + return fmt.Errorf("unpacking: %w", err) } - err = u.check() - if err != nil { - return err + if !firstRun { + err = u.check() + if err != nil { + return fmt.Errorf("checking config: %w", err) + } } - err = u.backup() + err = u.backup(firstRun) if err != nil { - return err + return fmt.Errorf("making backup: %w", err) } err = u.replace() if err != nil { - return err + return fmt.Errorf("replacing: %w", err) } return nil @@ -204,6 +213,7 @@ func (u *Updater) prepare(exePath string) (err error) { return nil } +// unpack extracts the files from the downloaded archive. func (u *Updater) unpack() error { var err error _, pkgNameOnly := filepath.Split(u.packageURL) @@ -228,38 +238,48 @@ func (u *Updater) unpack() error { return nil } +// check returns an error if the configuration file couldn't be used with the +// version of AdGuard Home just downloaded. func (u *Updater) check() error { log.Debug("updater: checking configuration") + err := copyFile(u.confName, filepath.Join(u.updateDir, "AdGuardHome.yaml")) if err != nil { return fmt.Errorf("copyFile() failed: %w", err) } + cmd := exec.Command(u.updateExeName, "--check-config") err = cmd.Run() if err != nil || cmd.ProcessState.ExitCode() != 0 { return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode()) } + return nil } -func (u *Updater) backup() error { +// backup makes a backup of the current configuration and supporting files. It +// ignores the configuration file if firstRun is true. +func (u *Updater) backup(firstRun bool) (err error) { log.Debug("updater: backing up current configuration") _ = os.Mkdir(u.backupDir, 0o755) - err := copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml")) - if err != nil { - return fmt.Errorf("copyFile() failed: %w", err) + if !firstRun { + err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml")) + if err != nil { + return fmt.Errorf("copyFile() failed: %w", err) + } } wd := u.workDir err = copySupportingFiles(u.unpackedFiles, wd, u.backupDir) if err != nil { - return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", - wd, u.backupDir, err) + return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", wd, u.backupDir, err) } return nil } +// replace moves the current executable with the updated one and also copies the +// supporting files. func (u *Updater) replace() error { err := copySupportingFiles(u.unpackedFiles, u.updateDir, u.workDir) if err != nil { @@ -287,6 +307,7 @@ func (u *Updater) replace() error { return nil } +// clean removes the temporary directory itself and all it's contents. func (u *Updater) clean() { _ = os.RemoveAll(u.updateDir) } @@ -297,9 +318,9 @@ func (u *Updater) clean() { const MaxPackageFileSize = 32 * 1024 * 1024 // Download package file and save it to disk -func (u *Updater) downloadPackageFile(url, filename string) (err error) { +func (u *Updater) downloadPackageFile() (err error) { var resp *http.Response - resp, err = u.client.Get(url) + resp, err = u.client.Get(u.packageURL) if err != nil { return fmt.Errorf("http request failed: %w", err) } @@ -321,7 +342,7 @@ func (u *Updater) downloadPackageFile(url, filename string) (err error) { _ = os.Mkdir(u.updateDir, 0o755) log.Debug("updater: saving package to file") - err = os.WriteFile(filename, body, 0o644) + err = os.WriteFile(u.packageName, body, 0o644) if err != nil { return fmt.Errorf("os.WriteFile() failed: %w", err) } diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index dbf0e069d93..af9093ccbe4 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -136,10 +136,10 @@ func TestUpdate(t *testing.T) { u.packageURL = fakeURL.String() require.NoError(t, u.prepare(exePath)) - require.NoError(t, u.downloadPackageFile(u.packageURL, u.packageName)) + require.NoError(t, u.downloadPackageFile()) require.NoError(t, u.unpack()) // require.NoError(t, u.check()) - require.NoError(t, u.backup()) + require.NoError(t, u.backup(false)) require.NoError(t, u.replace()) u.clean() @@ -215,10 +215,10 @@ func TestUpdateWindows(t *testing.T) { u.packageURL = fakeURL.String() require.NoError(t, u.prepare(exePath)) - require.NoError(t, u.downloadPackageFile(u.packageURL, u.packageName)) + require.NoError(t, u.downloadPackageFile()) require.NoError(t, u.unpack()) // assert.Nil(t, u.check()) - require.NoError(t, u.backup()) + require.NoError(t, u.backup(false)) require.NoError(t, u.replace()) u.clean()