diff --git a/pkg/agent/storage/gcs.go b/pkg/agent/storage/gcs.go index df342a093fe..e5f91bd2585 100644 --- a/pkg/agent/storage/gcs.go +++ b/pkg/agent/storage/gcs.go @@ -62,6 +62,7 @@ func (g *GCSObjectDownloader) Download(client stiface.Client, it stiface.ObjectI var errs []error // flag to help determine if query prefix returned an empty iterator var foundObject = false + filePath := filepath.Join(g.ModelDir, g.ModelName) for { attrs, err := it.Next() if err == iterator.Done { @@ -70,9 +71,12 @@ func (g *GCSObjectDownloader) Download(client stiface.Client, it stiface.ObjectI if err != nil { return fmt.Errorf("an error occurred while iterating: %v", err) } - foundObject = true objectValue := strings.TrimPrefix(attrs.Name, g.Item) fileName := filepath.Join(g.ModelDir, g.ModelName, objectValue) + if fileName == filePath { + continue + } + foundObject = true if FileExists(fileName) { log.Info("Deleting", fileName) if err := os.Remove(fileName); err != nil { diff --git a/pkg/agent/storage/s3.go b/pkg/agent/storage/s3.go index a5537db6ca8..6edf69fff24 100644 --- a/pkg/agent/storage/s3.go +++ b/pkg/agent/storage/s3.go @@ -87,9 +87,15 @@ func (s *S3ObjectDownloader) GetAllObjects(s3Svc s3iface.S3API) ([]s3manager.Bat return nil, fmt.Errorf("%s has no objects or does not exist", s.StorageUri) } + var foundObject = false + filePath := filepath.Join(s.ModelDir, s.ModelName) + for _, object := range resp.Contents { subObjectKey := strings.TrimPrefix(*object.Key, s.Prefix) fileName := filepath.Join(s.ModelDir, s.ModelName, subObjectKey) + if fileName == filePath { + continue + } if FileExists(fileName) { // File got corrupted or is mid-download :( // TODO: Figure out if we can maybe continue? @@ -112,8 +118,14 @@ func (s *S3ObjectDownloader) GetAllObjects(s3Svc s3iface.S3API) ([]s3manager.Bat return nil }, } + foundObject = true results = append(results, object) } + + if !foundObject { + return nil, fmt.Errorf("%s has no objects or does not exist", s.StorageUri) + } + return results, nil } diff --git a/pkg/agent/watcher_test.go b/pkg/agent/watcher_test.go index e7b10b11eff..6a3f3273db2 100644 --- a/pkg/agent/watcher_test.go +++ b/pkg/agent/watcher_test.go @@ -351,11 +351,11 @@ var _ = Describe("Watcher", func() { Fail("Failed to write contents.") } modelName := "model1" - modelStorageURI := "gs://testBucket/testModel1" + modelStorageURI := "gs://testBucket/" err := cl.DownloadModel(modelDir, modelName, modelStorageURI) Expect(err).To(BeNil()) - testFile := filepath.Join(modelDir, "model1") + testFile := filepath.Join(modelDir, modelName, "testModel1") dat, err := ioutil.ReadFile(testFile) Expect(err).To(BeNil()) Expect(string(dat)).To(Equal(modelContents))