Skip to content

Commit

Permalink
Prevents /mnt/models/<model name> from being converted into a file (k…
Browse files Browse the repository at this point in the history
…ubeflow#1549)

* prevents file directory from becoming a file

* updated gcs test file to be within the model directory
  • Loading branch information
abchoo authored Apr 23, 2021
1 parent 5a7e06a commit 57b548d
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
6 changes: 5 additions & 1 deletion pkg/agent/storage/gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func (g *GCSObjectDownloader) Download(client stiface.Client, it stiface.ObjectI
var errs []error
// flag to help determine if query prefix returned an empty iterator
var foundObject = false
filePath := filepath.Join(g.ModelDir, g.ModelName)
for {
attrs, err := it.Next()
if err == iterator.Done {
Expand All @@ -70,9 +71,12 @@ func (g *GCSObjectDownloader) Download(client stiface.Client, it stiface.ObjectI
if err != nil {
return fmt.Errorf("an error occurred while iterating: %v", err)
}
foundObject = true
objectValue := strings.TrimPrefix(attrs.Name, g.Item)
fileName := filepath.Join(g.ModelDir, g.ModelName, objectValue)
if fileName == filePath {
continue
}
foundObject = true
if FileExists(fileName) {
log.Info("Deleting", fileName)
if err := os.Remove(fileName); err != nil {
Expand Down
12 changes: 12 additions & 0 deletions pkg/agent/storage/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,15 @@ func (s *S3ObjectDownloader) GetAllObjects(s3Svc s3iface.S3API) ([]s3manager.Bat
return nil, fmt.Errorf("%s has no objects or does not exist", s.StorageUri)
}

var foundObject = false
filePath := filepath.Join(s.ModelDir, s.ModelName)

for _, object := range resp.Contents {
subObjectKey := strings.TrimPrefix(*object.Key, s.Prefix)
fileName := filepath.Join(s.ModelDir, s.ModelName, subObjectKey)
if fileName == filePath {
continue
}
if FileExists(fileName) {
// File got corrupted or is mid-download :(
// TODO: Figure out if we can maybe continue?
Expand All @@ -112,8 +118,14 @@ func (s *S3ObjectDownloader) GetAllObjects(s3Svc s3iface.S3API) ([]s3manager.Bat
return nil
},
}
foundObject = true
results = append(results, object)
}

if !foundObject {
return nil, fmt.Errorf("%s has no objects or does not exist", s.StorageUri)
}

return results, nil
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/agent/watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,11 @@ var _ = Describe("Watcher", func() {
Fail("Failed to write contents.")
}
modelName := "model1"
modelStorageURI := "gs://testBucket/testModel1"
modelStorageURI := "gs://testBucket/"
err := cl.DownloadModel(modelDir, modelName, modelStorageURI)
Expect(err).To(BeNil())

testFile := filepath.Join(modelDir, "model1")
testFile := filepath.Join(modelDir, modelName, "testModel1")
dat, err := ioutil.ReadFile(testFile)
Expect(err).To(BeNil())
Expect(string(dat)).To(Equal(modelContents))
Expand Down

0 comments on commit 57b548d

Please sign in to comment.