Skip to content

Commit

Permalink
replace liveBackupGetter type with function variable
Browse files Browse the repository at this point in the history
Signed-off-by: Steve Kriss <steve@heptio.com>
  • Loading branch information
skriss committed Jul 28, 2018
1 parent 5802afe commit 1f36ed4
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 102 deletions.
22 changes: 2 additions & 20 deletions pkg/cloudprovider/backup_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ type BackupLister interface {
ListBackups(bucket string) ([]*api.Backup, error)
}

type BackupGetter interface {
// GetBackup retrieves a backup from object storage.
GetBackup(bucket, backupName string) (*api.Backup, error)
}

const (
metadataFileFormatString = "%s/ark-backup.json"
backupFileFormatString = "%s/%s.tar.gz"
Expand Down Expand Up @@ -192,21 +187,8 @@ func ListBackups(logger logrus.FieldLogger, objectStore ObjectStore, bucket stri
return output, nil
}

type liveBackupGetter struct {
logger logrus.FieldLogger
objectStore ObjectStore
}

func NewLiveBackupGetter(logger logrus.FieldLogger, objectStore ObjectStore) BackupGetter {
return &liveBackupGetter{
logger: logger,
objectStore: objectStore,
}
}

func (l *liveBackupGetter) GetBackup(bucket, backupName string) (*api.Backup, error) {
return GetBackup(l.objectStore, bucket, backupName)
}
//GetBackupFunc is a function that can retrieve backup metadata from an object store
type GetBackupFunc func(objectStore ObjectStore, bucket, backupName string) (*api.Backup, error)

// GetBackup gets the specified api.Backup from the given bucket in object storage.
func GetBackup(objectStore ObjectStore, bucket, backupName string) (*api.Backup, error) {
Expand Down
48 changes: 0 additions & 48 deletions pkg/cloudprovider/mocks/backup_getter.go

This file was deleted.

26 changes: 11 additions & 15 deletions pkg/controller/restore_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ type restoreController struct {
pluginRegistry plugin.Registry
metrics *metrics.ServerMetrics

getBackup cloudprovider.GetBackupFunc
downloadBackup cloudprovider.DownloadBackupFunc
uploadRestoreLog cloudprovider.UploadRestoreLogFunc
uploadRestoreResults cloudprovider.UploadRestoreResultsFunc
newPluginManager func(logger logrus.FieldLogger, logLevel logrus.Level, pluginRegistry plugin.Registry) plugin.Manager
newBackupGetter func(logger logrus.FieldLogger, objectStore cloudprovider.ObjectStore) cloudprovider.BackupGetter
}

func NewRestoreController(
Expand Down Expand Up @@ -130,14 +130,13 @@ func NewRestoreController(
pluginRegistry: pluginRegistry,
metrics: metrics,

downloadBackup: cloudprovider.DownloadBackup,

getBackup: cloudprovider.GetBackup,
downloadBackup: cloudprovider.DownloadBackup,
uploadRestoreLog: cloudprovider.UploadRestoreLog,
uploadRestoreResults: cloudprovider.UploadRestoreResults,
newPluginManager: func(logger logrus.FieldLogger, logLevel logrus.Level, pluginRegistry plugin.Registry) plugin.Manager {
return plugin.NewManager(logger, logLevel, pluginRegistry)
},
newBackupGetter: func(logger logrus.FieldLogger, objectStore cloudprovider.ObjectStore) cloudprovider.BackupGetter {
return cloudprovider.NewLiveBackupGetter(logger, objectStore)
},
}

c.syncHandler = c.processRestore
Expand Down Expand Up @@ -288,15 +287,14 @@ func (c *restoreController) processRestore(key string) error {
if err != nil {
return errors.Wrap(err, "error initializing object store")
}
backupGetter := c.newBackupGetter(logContext, objectStore)

actions, err := pluginManager.GetRestoreItemActions()
if err != nil {
return errors.Wrap(err, "error initializing restore item actions")
}

// complete & validate restore
if restore.Status.ValidationErrors = c.completeAndValidate(backupGetter, restore); len(restore.Status.ValidationErrors) > 0 {
if restore.Status.ValidationErrors = c.completeAndValidate(objectStore, restore); len(restore.Status.ValidationErrors) > 0 {
restore.Status.Phase = api.RestorePhaseFailedValidation
} else {
restore.Status.Phase = api.RestorePhaseInProgress
Expand Down Expand Up @@ -325,7 +323,6 @@ func (c *restoreController) processRestore(key string) error {
restore,
actions,
objectStore,
backupGetter,
)

restore.Status.Warnings = len(restoreWarnings.Ark) + len(restoreWarnings.Cluster)
Expand Down Expand Up @@ -358,7 +355,7 @@ func (c *restoreController) processRestore(key string) error {
return nil
}

func (c *restoreController) completeAndValidate(backupGetter cloudprovider.BackupGetter, restore *api.Restore) []string {
func (c *restoreController) completeAndValidate(objectStore cloudprovider.ObjectStore, restore *api.Restore) []string {
// add non-restorable resources to restore's excluded resources
excludedResources := sets.NewString(restore.Spec.ExcludedResources...)
for _, nonrestorable := range nonRestorableResources {
Expand Down Expand Up @@ -422,7 +419,7 @@ func (c *restoreController) completeAndValidate(backupGetter cloudprovider.Backu
backup *api.Backup
err error
)
if backup, err = c.fetchBackup(backupGetter, restore.Spec.BackupName); err != nil {
if backup, err = c.fetchBackup(objectStore, restore.Spec.BackupName); err != nil {
return append(validationErrors, fmt.Sprintf("Error retrieving backup: %v", err))
}

Expand Down Expand Up @@ -465,7 +462,7 @@ func mostRecentCompletedBackup(backups []*api.Backup) *api.Backup {
return nil
}

func (c *restoreController) fetchBackup(backupGetter cloudprovider.BackupGetter, name string) (*api.Backup, error) {
func (c *restoreController) fetchBackup(objectStore cloudprovider.ObjectStore, name string) (*api.Backup, error) {
backup, err := c.backupLister.Backups(c.namespace).Get(name)
if err == nil {
return backup, nil
Expand All @@ -478,7 +475,7 @@ func (c *restoreController) fetchBackup(backupGetter cloudprovider.BackupGetter,
logContext := c.logger.WithField("backupName", name)

logContext.Debug("Backup not found in backupLister, checking object storage directly")
backup, err = backupGetter.GetBackup(c.bucket, name)
backup, err = c.getBackup(objectStore, c.bucket, name)
if err != nil {
return nil, err
}
Expand All @@ -502,7 +499,6 @@ func (c *restoreController) runRestore(
restore *api.Restore,
actions []restore.ItemAction,
objectStore cloudprovider.ObjectStore,
backupGetter cloudprovider.BackupGetter,
) (restoreWarnings, restoreErrors api.RestoreResult, restoreFailure error) {
logFile, err := ioutil.TempFile("", "")
if err != nil {
Expand Down Expand Up @@ -534,7 +530,7 @@ func (c *restoreController) runRestore(
"backup": restore.Spec.BackupName,
})

backup, err := c.fetchBackup(backupGetter, restore.Spec.BackupName)
backup, err := c.fetchBackup(objectStore, restore.Spec.BackupName)
if err != nil {
logContext.WithError(err).Error("Error getting backup")
restoreErrors.Ark = append(restoreErrors.Ark, err.Error())
Expand Down
30 changes: 11 additions & 19 deletions pkg/controller/restore_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (

api "github.com/heptio/ark/pkg/apis/ark/v1"
"github.com/heptio/ark/pkg/cloudprovider"
cloudprovidermocks "github.com/heptio/ark/pkg/cloudprovider/mocks"
"github.com/heptio/ark/pkg/generated/clientset/versioned/fake"
informers "github.com/heptio/ark/pkg/generated/informers/externalversions"
"github.com/heptio/ark/pkg/metrics"
Expand Down Expand Up @@ -85,9 +84,7 @@ func TestFetchBackup(t *testing.T) {
restorer = &fakeRestorer{}
sharedInformers = informers.NewSharedInformerFactory(client, 0)
logger = arktest.NewLogger()
backupGetter = &cloudprovidermocks.BackupGetter{}
)
defer backupGetter.AssertExpectations(t)

c := NewRestoreController(
api.DefaultNamespace,
Expand All @@ -104,19 +101,20 @@ func TestFetchBackup(t *testing.T) {
nil, //pluginRegistry
metrics.NewServerMetrics(),
).(*restoreController)
c.newBackupGetter = func(logger logrus.FieldLogger, objectStore cloudprovider.ObjectStore) cloudprovider.BackupGetter {
return backupGetter
}

for _, itm := range test.informerBackups {
sharedInformers.Ark().V1().Backups().Informer().GetStore().Add(itm)
}

if test.backupServiceBackup != nil || test.backupServiceError != nil {
backupGetter.On("GetBackup", "bucket", test.backupName).Return(test.backupServiceBackup, test.backupServiceError)
c.getBackup = func(_ cloudprovider.ObjectStore, bucket, backup string) (*api.Backup, error) {
require.Equal(t, "bucket", bucket)
require.Equal(t, test.backupName, backup)
return test.backupServiceBackup, test.backupServiceError
}
}

backup, err := c.fetchBackup(backupGetter, test.backupName)
backup, err := c.fetchBackup(nil, test.backupName)

if assert.Equal(t, test.expectedErr, err != nil) {
assert.Equal(t, test.expectedRes, backup)
Expand Down Expand Up @@ -166,11 +164,9 @@ func TestProcessRestoreSkips(t *testing.T) {
sharedInformers = informers.NewSharedInformerFactory(client, 0)
logger = arktest.NewLogger()
pluginManager = &pluginmocks.Manager{}
backupGetter = &cloudprovidermocks.BackupGetter{}
objectStore = &arktest.ObjectStore{}
)
defer restorer.AssertExpectations(t)
defer backupGetter.AssertExpectations(t)
defer objectStore.AssertExpectations(t)

c := NewRestoreController(
Expand All @@ -191,9 +187,6 @@ func TestProcessRestoreSkips(t *testing.T) {
c.newPluginManager = func(logger logrus.FieldLogger, logLevel logrus.Level, pluginRegistry plugin.Registry) plugin.Manager {
return pluginManager
}
c.newBackupGetter = func(logger logrus.FieldLogger, objectStore cloudprovider.ObjectStore) cloudprovider.BackupGetter {
return backupGetter
}

if test.restore != nil {
sharedInformers.Ark().V1().Restores().Informer().GetStore().Add(test.restore)
Expand Down Expand Up @@ -381,11 +374,9 @@ func TestProcessRestore(t *testing.T) {
sharedInformers = informers.NewSharedInformerFactory(client, 0)
logger = arktest.NewLogger()
pluginManager = &pluginmocks.Manager{}
backupGetter = &cloudprovidermocks.BackupGetter{}
objectStore = &arktest.ObjectStore{}
)
defer restorer.AssertExpectations(t)
defer backupGetter.AssertExpectations(t)
defer objectStore.AssertExpectations(t)

c := NewRestoreController(
Expand All @@ -406,9 +397,6 @@ func TestProcessRestore(t *testing.T) {
c.newPluginManager = func(logger logrus.FieldLogger, logLevel logrus.Level, pluginRegistry plugin.Registry) plugin.Manager {
return pluginManager
}
c.newBackupGetter = func(logger logrus.FieldLogger, objectStore cloudprovider.ObjectStore) cloudprovider.BackupGetter {
return backupGetter
}

if test.restore != nil {
pluginManager.On("GetObjectStore", "myCloud").Return(objectStore, nil)
Expand Down Expand Up @@ -495,7 +483,11 @@ func TestProcessRestore(t *testing.T) {
}

if test.backupServiceGetBackupError != nil {
backupGetter.On("GetBackup", "bucket", test.restore.Spec.BackupName).Return(nil, test.backupServiceGetBackupError)
c.getBackup = func(_ cloudprovider.ObjectStore, bucket, backup string) (*api.Backup, error) {
require.Equal(t, "bucket", bucket)
require.Equal(t, test.restore.Spec.BackupName, backup)
return nil, test.backupServiceGetBackupError
}
}

if test.backupServiceDownloadBackupError != nil {
Expand Down

0 comments on commit 1f36ed4

Please sign in to comment.