From e0d3c287a7e1c05b1e397f4727c447a1fcd9f7f6 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Thu, 29 Dec 2022 03:22:49 +0400 Subject: [PATCH] all: update on first run --- internal/home/controlupdate.go | 2 +- internal/home/home.go | 8 +---- internal/updater/updater.go | 60 ++++++++++++++++++-------------- internal/updater/updater_test.go | 8 ++--- 4 files changed, 40 insertions(+), 38 deletions(-) diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index ef4f06592e8..5718bfaa003 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -123,7 +123,7 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) { return } - err = Context.updater.Update() + err = Context.updater.Update(false) if err != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) diff --git a/internal/home/home.go b/internal/home/home.go index b4788b0ceac..a3609e06214 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -1036,12 +1036,6 @@ func cmdlineUpdate(opts options) { err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}) fatalOnError(err) - if Context.firstRun { - log.Info("updates are not allowed on first run") - - os.Exit(0) - } - log.Info("cmdline update: performing update") updater := Context.updater @@ -1059,7 +1053,7 @@ func cmdlineUpdate(opts options) { os.Exit(0) } - err = updater.Update() + err = updater.Update(Context.firstRun) fatalOnError(err) os.Exit(0) diff --git a/internal/updater/updater.go b/internal/updater/updater.go index c84985ed04f..9bfd3748a52 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -105,48 +105,56 @@ func NewUpdater(conf *Config) *Updater { } // Update performs the auto-update. -func (u *Updater) Update() (err error) { +func (u *Updater) Update(firstRun bool) (err error) { u.mu.Lock() defer u.mu.Unlock() log.Info("updater: updating") - defer func() { log.Info("updater: finished; errors: %v", err) }() + defer func() { + if err != nil { + log.Error("updater: failed: %v", err) + } else { + log.Info("updater: finished") + } + }() execPath, err := os.Executable() if err != nil { - return err + return fmt.Errorf("getting executable path: %w", err) } err = u.prepare(execPath) if err != nil { - return err + return fmt.Errorf("preparing: %w", err) } defer u.clean() - err = u.downloadPackageFile(u.packageURL, u.packageName) + err = u.downloadPackageFile() if err != nil { - return err + return fmt.Errorf("downloading package file: %w", err) } err = u.unpack() if err != nil { - return err + return fmt.Errorf("unpacking: %w", err) } - err = u.check() - if err != nil { - return err + if !firstRun { + err = u.check() + if err != nil { + return fmt.Errorf("checking config: %w", err) + } } - err = u.backup() + err = u.backup(firstRun) if err != nil { - return err + return fmt.Errorf("making backup: %w", err) } err = u.replace() if err != nil { - return err + return fmt.Errorf("replacing: %w", err) } return nil @@ -230,31 +238,35 @@ func (u *Updater) unpack() error { func (u *Updater) check() error { log.Debug("updater: checking configuration") + err := copyFile(u.confName, filepath.Join(u.updateDir, "AdGuardHome.yaml")) if err != nil { return fmt.Errorf("copyFile() failed: %w", err) } + cmd := exec.Command(u.updateExeName, "--check-config") err = cmd.Run() if err != nil || cmd.ProcessState.ExitCode() != 0 { return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode()) } + return nil } -func (u *Updater) backup() error { +func (u *Updater) backup(firstRun bool) (err error) { log.Debug("updater: backing up current configuration") _ = os.Mkdir(u.backupDir, 0o755) - err := copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml")) - if err != nil { - return fmt.Errorf("copyFile() failed: %w", err) + if !firstRun { + err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml")) + if err != nil { + return fmt.Errorf("copyFile() failed: %w", err) + } } wd := u.workDir err = copySupportingFiles(u.unpackedFiles, wd, u.backupDir) if err != nil { - return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", - wd, u.backupDir, err) + return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", wd, u.backupDir, err) } return nil @@ -297,9 +309,9 @@ func (u *Updater) clean() { const MaxPackageFileSize = 32 * 1024 * 1024 // Download package file and save it to disk -func (u *Updater) downloadPackageFile(url, filename string) (err error) { +func (u *Updater) downloadPackageFile() (err error) { var resp *http.Response - resp, err = u.client.Get(url) + resp, err = u.client.Get(u.packageURL) if err != nil { return fmt.Errorf("http request failed: %w", err) } @@ -321,7 +333,7 @@ func (u *Updater) downloadPackageFile(url, filename string) (err error) { _ = os.Mkdir(u.updateDir, 0o755) log.Debug("updater: saving package to file") - err = os.WriteFile(filename, body, 0o644) + err = os.WriteFile(u.packageName, body, 0o644) if err != nil { return fmt.Errorf("os.WriteFile() failed: %w", err) } @@ -504,10 +516,6 @@ func zipFileUnpack(zipfile, outDir string) (files []string, err error) { // Copy file on disk func copyFile(src, dst string) error { - if src == "" || src == dst { - return nil - } - d, e := os.ReadFile(src) if e != nil { return e diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index dbf0e069d93..af9093ccbe4 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -136,10 +136,10 @@ func TestUpdate(t *testing.T) { u.packageURL = fakeURL.String() require.NoError(t, u.prepare(exePath)) - require.NoError(t, u.downloadPackageFile(u.packageURL, u.packageName)) + require.NoError(t, u.downloadPackageFile()) require.NoError(t, u.unpack()) // require.NoError(t, u.check()) - require.NoError(t, u.backup()) + require.NoError(t, u.backup(false)) require.NoError(t, u.replace()) u.clean() @@ -215,10 +215,10 @@ func TestUpdateWindows(t *testing.T) { u.packageURL = fakeURL.String() require.NoError(t, u.prepare(exePath)) - require.NoError(t, u.downloadPackageFile(u.packageURL, u.packageName)) + require.NoError(t, u.downloadPackageFile()) require.NoError(t, u.unpack()) // assert.Nil(t, u.check()) - require.NoError(t, u.backup()) + require.NoError(t, u.backup(false)) require.NoError(t, u.replace()) u.clean()