Skip to content

Commit

Permalink
Puller streamlining/simplification (kubeflow#1057)
Browse files Browse the repository at this point in the history
* Puller streamlining/simplification

Follow-on changes to kubeflow#989 based on remaining review suggestions.

- Simplified configmap change diffing
- Connect watcher and puller with event channel
- Have puller track in-progress ops per model via op completion channel and tie lifecycle of per-model channel+goroutine pairs to this

* Minor change: fully decouple puller from watcher

* Address some of the review comments

The complete ModelOp struct is now passed all the way back and forth.
  • Loading branch information
njhill authored Sep 19, 2020
1 parent d883c34 commit 773b10f
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 189 deletions.
19 changes: 2 additions & 17 deletions cmd/agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,7 @@ func main() {
}
}

watcher := agent.Watcher{
ConfigDir: *configDir,
ModelTracker: map[string]agent.ModelWrapper{},
Puller: agent.Puller{
ChannelMap: map[string]agent.Channel{},
Downloader: downloader,
},
}

syncer := agent.Syncer{
Watcher: watcher,
}

// Doing a forced sync in the case of container failures
// and for pre-filled config maps
syncer.Start()
watcher := agent.NewWatcher(*configDir, *modelDir)
agent.StartPuller(downloader, watcher.ModelEvents)
watcher.Start()

}
9 changes: 3 additions & 6 deletions pkg/agent/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/hex"
"fmt"
"github.com/kubeflow/kfserving/pkg/agent/storage"
"github.com/kubeflow/kfserving/pkg/apis/serving/v1beta1"
"log"
"path/filepath"
"regexp"
Expand All @@ -17,9 +18,7 @@ type Downloader struct {

var SupportedProtocols = []storage.Protocol{storage.S3}

func (d *Downloader) DownloadModel(event EventWrapper) error {
modelSpec := event.ModelSpec
modelName := event.ModelName
func (d *Downloader) DownloadModel(modelName string, modelSpec *v1beta1.ModelSpec) error {
if modelSpec != nil {
modelUri := modelSpec.StorageURI
hashModelUri := hash(modelUri)
Expand All @@ -29,7 +28,7 @@ func (d *Downloader) DownloadModel(event EventWrapper) error {
successFile := filepath.Join(d.ModelDir, modelName,
fmt.Sprintf("SUCCESS.%s.%s.%s", hashModelUri, hashFramework, hashMemory))
// Download only if a SUCCESS marker file for this exact spec does not already exist
if !storage.FileExists(successFile) && event.ShouldDownload {
if !storage.FileExists(successFile) {
// TODO: Handle retry logic
if err := d.download(modelName, modelUri); err != nil {
return fmt.Errorf("download error: %v", err)
Expand All @@ -39,8 +38,6 @@ func (d *Downloader) DownloadModel(event EventWrapper) error {
return fmt.Errorf("create file error: %v", createErr)
}
defer file.Close()
} else if !event.ShouldDownload {
log.Println("Model", modelName, "does not need to be re-downloaded")
} else {
log.Println("Model", modelSpec.StorageURI, "exists already")
}
Expand Down
120 changes: 84 additions & 36 deletions pkg/agent/puller.go
Original file line number Diff line number Diff line change
@@ -1,73 +1,121 @@
package agent

import (
"fmt"
"github.com/kubeflow/kfserving/pkg/agent/storage"
v1 "github.com/kubeflow/kfserving/pkg/apis/serving/v1beta1"
"log"
"path/filepath"
)

// OpType identifies the kind of operation requested for a model.
type OpType string

const (
	// Add requests that a model be downloaded and loaded.
	Add OpType = "Add"
	// Remove requests that a model be unloaded and its files deleted.
	Remove OpType = "Remove"
)

type Puller struct {
ChannelMap map[string]Channel
Downloader Downloader
channelMap map[string]ModelChannel
completions chan *ModelOp
Downloader Downloader
}

type Channel struct {
EventChannel chan EventWrapper
type ModelOp struct {
ModelName string
Op OpType
Spec *v1.ModelSpec
}

func (p *Puller) AddModel(modelName string) Channel {
// TODO: Figure out the appropriate buffer-size for this
// TODO: Check if event Channel exists
eventChannel := make(chan EventWrapper, 20)
channel := Channel{
EventChannel: eventChannel,
func StartPuller(downloader Downloader, commands <-chan ModelOp) {
puller := Puller{
channelMap: make(map[string]ModelChannel),
completions: make(chan *ModelOp, 4),
Downloader: downloader,
}
go p.modelProcessor(modelName, channel.EventChannel)
p.ChannelMap[modelName] = channel
return p.ChannelMap[modelName]
go puller.processCommands(commands)
}

func (p *Puller) RemoveModel(modelName string) error {
channel, ok := p.ChannelMap[modelName]
if ok {
close(channel.EventChannel)
delete(p.ChannelMap, modelName)
func (p *Puller) processCommands(commands <-chan ModelOp) {
// channelMap accessed only by this goroutine
for {
select {
case modelOp, ok := <-commands:
if ok {
p.enqueueModelOp(&modelOp)
} else {
commands = nil
}
case completed := <-p.completions:
p.modelOpComplete(completed, commands == nil)
}
}
}

// ModelChannel is a per-model work queue: the channel of pending ops for
// one model plus the count of ops submitted but not yet completed.
type ModelChannel struct {
	modelOps    chan *ModelOp
	opsInFlight int
}

func (p *Puller) enqueueModelOp(modelOp *ModelOp) {
modelChan, ok := p.channelMap[modelOp.ModelName]
if !ok {
modelChan = ModelChannel{
modelOps: make(chan *ModelOp, 8),
}
go p.modelProcessor(modelOp.ModelName, modelChan.modelOps)
p.channelMap[modelOp.ModelName] = modelChan
}
if err := storage.RemoveDir(filepath.Join(p.Downloader.ModelDir, modelName)); err != nil {
return fmt.Errorf("failing to delete model directory: %v", err)
modelChan.opsInFlight += 1
modelChan.modelOps <- modelOp
}

// modelOpComplete accounts for a finished op: it decrements the model's
// in-flight count, tears down the worker queue when the count reaches zero,
// and — once the command channel is closed (closed==true) and no models
// remain — closes the completions channel to signal full shutdown.
func (p *Puller) modelOpComplete(modelOp *ModelOp, closed bool) {
	modelChan, ok := p.channelMap[modelOp.ModelName]
	if !ok {
		log.Println("Op completion event for model", modelOp.ModelName, "not found in channelMap")
		return
	}
	modelChan.opsInFlight -= 1
	if modelChan.opsInFlight > 0 {
		// ModelChannel is stored by value; persist the decremented count so
		// later completions can drive it to zero.
		p.channelMap[modelOp.ModelName] = modelChan
		return
	}
	close(modelChan.modelOps)
	delete(p.channelMap, modelOp.ModelName)
	if closed && len(p.channelMap) == 0 {
		// this was the final completion, close the channel
		close(p.completions)
	}
}

func (p *Puller) modelProcessor(modelName string, events chan EventWrapper) {
func (p *Puller) modelProcessor(modelName string, ops <-chan *ModelOp) {
log.Println("worker for", modelName, "is initialized")
// TODO: Instead of going through each event, one-by-one, we need to drain and combine
// this is important for handling Load --> Unload requests sent in tandem
// Load --> Unload = 0 (cancel first load)
// Load --> Unload --> Load = 1 Load (cancel second load?)
for event := range events {
log.Println("worker", modelName, "started job", event)
switch event.LoadState {
case ShouldLoad:
log.Println("Should download", event.ModelSpec.StorageURI)
err := p.Downloader.DownloadModel(event)
for modelOp := range ops {
switch modelOp.Op {
case Add:
// Load
log.Println("Should download", modelOp.Spec.StorageURI)
err := p.Downloader.DownloadModel(modelName, modelOp.Spec)
if err != nil {
log.Println("worker failed on", event, "because: ", err)
log.Println("Download of model", modelName, "failed because: ", err)
} else {
// If there is an error, we will NOT send a request. As such, to know about errors, you will
// need to call the error endpoint of the puller
// TODO: Do request logic
log.Println("Now doing a request on", event)
log.Println("Now doing load request for", modelName)
}
case ShouldUnload:
case Remove:
// Unload
// TODO: Do request logic
log.Println("Now doing a request on", event)
log.Println("Now doing unload request for", modelName)
// If there is an error, we will NOT do a delete... that could be problematic
log.Println("Should unload", event.ModelName)
if err := p.RemoveModel(event.ModelName); err != nil {
log.Println("worker failed on", event, "because: ", err)
log.Println("Should unload", modelName)
if err := storage.RemoveDir(filepath.Join(p.Downloader.ModelDir, modelName)); err != nil {
log.Printf("failing to delete model directory: %v", err)
}
}
p.completions <- modelOp
}
}
60 changes: 20 additions & 40 deletions pkg/agent/syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,13 @@ package agent

import (
"encoding/hex"
"encoding/json"
"fmt"
"github.com/kubeflow/kfserving/pkg/apis/serving/v1beta1"
"github.com/kubeflow/kfserving/pkg/constants"
"github.com/kubeflow/kfserving/pkg/modelconfig"
"io/ioutil"
"k8s.io/apimachinery/pkg/api/resource"
"log"
"os"
"path/filepath"
"strings"
"time"
)

type Syncer struct {
Expand All @@ -24,9 +19,8 @@ type FileError error

var NoSuccessFile FileError = fmt.Errorf("no success file can be found")

func (s *Syncer) Start() {
modelDir := filepath.Clean(s.Watcher.Puller.Downloader.ModelDir)
timeNow := time.Now()
func SyncModelDir(modelDir string) (map[string]modelWrapper, error) {
modelTracker := make(map[string]modelWrapper)
err := filepath.Walk(modelDir, func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
modelName := info.Name()
Expand All @@ -35,10 +29,15 @@ func (s *Syncer) Start() {
base := filepath.Base(path)
baseSplit := strings.SplitN(base, ".", 4)
if baseSplit[0] == "SUCCESS" {
if e := s.successParse(timeNow, modelName, baseSplit); e != nil {
if spec, e := successParse(modelName, baseSplit); e != nil {
return fmt.Errorf("error parsing SUCCESS file: %v", e)
} else {
modelTracker[modelName] = modelWrapper{
Spec: spec,
stale: true,
}
return nil
}
return nil
}
}
return NoSuccessFile
Expand All @@ -54,50 +53,31 @@ func (s *Syncer) Start() {
return nil
})
if err != nil {
log.Println("error in going through:", modelDir, err)
}
filePath := filepath.Join(s.Watcher.ConfigDir, constants.ModelConfigFileName)
log.Println("Syncing of", filePath)
file, err := ioutil.ReadFile(filePath)
if err != nil {
log.Println("Error in reading file", err)
} else {
modelConfigs := make(modelconfig.ModelConfigs, 0)
err = json.Unmarshal([]byte(file), &modelConfigs)
if err != nil {
log.Println("unable to marshall for modelConfig with error", err)
}
s.Watcher.ParseConfig(modelConfigs)
return nil, fmt.Errorf("error in syncing %s: %w", modelDir, err)
}
return modelTracker, nil
}

func (s *Syncer) successParse(timeNow time.Time, modelName string, baseSplit []string) error {
func successParse(modelName string, baseSplit []string) (*v1beta1.ModelSpec, error) {
storageURI, err := unhash(baseSplit[1])
errorMessage := "unable to unhash the SUCCESS file, maybe the SUCCESS file has been modified?: %v"
if err != nil {
return fmt.Errorf(errorMessage, err)
return nil, fmt.Errorf(errorMessage, err)
}
framework, err := unhash(baseSplit[2])
if err != nil {
return fmt.Errorf(errorMessage, err)
return nil, fmt.Errorf(errorMessage, err)
}
memory, err := unhash(baseSplit[3])
if err != nil {
return fmt.Errorf(errorMessage, err)
return nil, fmt.Errorf(errorMessage, err)
}
memoryResource := resource.MustParse(memory)

s.Watcher.ModelTracker[modelName] = ModelWrapper{
ModelSpec: &v1beta1.ModelSpec{
StorageURI: storageURI,
Framework: framework,
Memory: memoryResource,
},
Time: timeNow,
Stale: true,
Redownload: true,
}
return nil
return &v1beta1.ModelSpec{
StorageURI: storageURI,
Framework: framework,
Memory: memoryResource,
}, nil
}

func unhash(s string) (string, error) {
Expand Down
Loading

0 comments on commit 773b10f

Please sign in to comment.