forked from kubeflow/pipelines
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Puller streamlining/simplification (kubeflow#1057)
* Puller streamlining/simplification Follow-on changes to kubeflow#989 based on remaining review suggestions. - Simplified configmap change diffing - Connect watcher and puller with event channel - Have puller track in-progress ops per model via op completion channel and tie lifecycle of per-model channel+goroutine pairs to this * Minor change: fully decouple puller from watcher * Address some of the review comments The complete ModelOp struct is now passed all the way back and forth.
- Loading branch information
Showing
5 changed files
with
168 additions
and
189 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,121 @@ | ||
package agent | ||
|
||
import ( | ||
"fmt" | ||
"github.com/kubeflow/kfserving/pkg/agent/storage" | ||
v1 "github.com/kubeflow/kfserving/pkg/apis/serving/v1beta1" | ||
"log" | ||
"path/filepath" | ||
) | ||
|
||
// OpType names the kind of operation carried by a ModelOp.
type OpType string | ||
|
||
// The two operations modelProcessor switches on.
const ( | ||
// Add makes the worker download the model described by the op's Spec.
Add OpType = "Add" | ||
// Remove makes the worker delete the model's directory under Downloader.ModelDir.
Remove OpType = "Remove" | ||
) | ||
|
||
type Puller struct { | ||
ChannelMap map[string]Channel | ||
Downloader Downloader | ||
channelMap map[string]ModelChannel | ||
completions chan *ModelOp | ||
Downloader Downloader | ||
} | ||
|
||
type Channel struct { | ||
EventChannel chan EventWrapper | ||
type ModelOp struct { | ||
ModelName string | ||
Op OpType | ||
Spec *v1.ModelSpec | ||
} | ||
|
||
func (p *Puller) AddModel(modelName string) Channel { | ||
// TODO: Figure out the appropriate buffer-size for this | ||
// TODO: Check if event Channel exists | ||
eventChannel := make(chan EventWrapper, 20) | ||
channel := Channel{ | ||
EventChannel: eventChannel, | ||
func StartPuller(downloader Downloader, commands <-chan ModelOp) { | ||
puller := Puller{ | ||
channelMap: make(map[string]ModelChannel), | ||
completions: make(chan *ModelOp, 4), | ||
Downloader: downloader, | ||
} | ||
go p.modelProcessor(modelName, channel.EventChannel) | ||
p.ChannelMap[modelName] = channel | ||
return p.ChannelMap[modelName] | ||
go puller.processCommands(commands) | ||
} | ||
|
||
func (p *Puller) RemoveModel(modelName string) error { | ||
channel, ok := p.ChannelMap[modelName] | ||
if ok { | ||
close(channel.EventChannel) | ||
delete(p.ChannelMap, modelName) | ||
func (p *Puller) processCommands(commands <-chan ModelOp) { | ||
// channelMap accessed only by this goroutine | ||
for { | ||
select { | ||
case modelOp, ok := <-commands: | ||
if ok { | ||
p.enqueueModelOp(&modelOp) | ||
} else { | ||
commands = nil | ||
} | ||
case completed := <-p.completions: | ||
p.modelOpComplete(completed, commands == nil) | ||
} | ||
} | ||
} | ||
|
||
type ModelChannel struct { | ||
modelOps chan *ModelOp | ||
opsInFlight int | ||
} | ||
|
||
func (p *Puller) enqueueModelOp(modelOp *ModelOp) { | ||
modelChan, ok := p.channelMap[modelOp.ModelName] | ||
if !ok { | ||
modelChan = ModelChannel{ | ||
modelOps: make(chan *ModelOp, 8), | ||
} | ||
go p.modelProcessor(modelOp.ModelName, modelChan.modelOps) | ||
p.channelMap[modelOp.ModelName] = modelChan | ||
} | ||
if err := storage.RemoveDir(filepath.Join(p.Downloader.ModelDir, modelName)); err != nil { | ||
return fmt.Errorf("failing to delete model directory: %v", err) | ||
modelChan.opsInFlight += 1 | ||
modelChan.modelOps <- modelOp | ||
} | ||
|
||
func (p *Puller) modelOpComplete(modelOp *ModelOp, closed bool) { | ||
modelChan, ok := p.channelMap[modelOp.ModelName] | ||
if ok { | ||
modelChan.opsInFlight -= 1 | ||
if modelChan.opsInFlight == 0 { | ||
close(modelChan.modelOps) | ||
delete(p.channelMap, modelOp.ModelName) | ||
if closed && len(p.channelMap) == 0 { | ||
// this was the final completion, close the channel | ||
close(p.completions) | ||
} | ||
} | ||
} else { | ||
log.Println("Op completion event for model", modelOp.ModelName, "not found in channelMap") | ||
} | ||
return nil | ||
} | ||
|
||
func (p *Puller) modelProcessor(modelName string, events chan EventWrapper) { | ||
func (p *Puller) modelProcessor(modelName string, ops <-chan *ModelOp) { | ||
log.Println("worker for", modelName, "is initialized") | ||
// TODO: Instead of going through each event, one-by-one, we need to drain and combine | ||
// this is important for handling Load --> Unload requests sent in tandem | ||
// Load --> Unload = 0 (cancel first load) | ||
// Load --> Unload --> Load = 1 Load (cancel second load?) | ||
for event := range events { | ||
log.Println("worker", modelName, "started job", event) | ||
switch event.LoadState { | ||
case ShouldLoad: | ||
log.Println("Should download", event.ModelSpec.StorageURI) | ||
err := p.Downloader.DownloadModel(event) | ||
for modelOp := range ops { | ||
switch modelOp.Op { | ||
case Add: | ||
// Load | ||
log.Println("Should download", modelOp.Spec.StorageURI) | ||
err := p.Downloader.DownloadModel(modelName, modelOp.Spec) | ||
if err != nil { | ||
log.Println("worker failed on", event, "because: ", err) | ||
log.Println("Download of model", modelName, "failed because: ", err) | ||
} else { | ||
// If there is an error, we will NOT send a request. As such, to know about errors, you will | ||
// need to call the error endpoint of the puller | ||
// TODO: Do request logic | ||
log.Println("Now doing a request on", event) | ||
log.Println("Now doing load request for", modelName) | ||
} | ||
case ShouldUnload: | ||
case Remove: | ||
// Unload | ||
// TODO: Do request logic | ||
log.Println("Now doing a request on", event) | ||
log.Println("Now doing unload request for", modelName) | ||
// If there is an error, we will NOT do a delete... that could be problematic | ||
log.Println("Should unload", event.ModelName) | ||
if err := p.RemoveModel(event.ModelName); err != nil { | ||
log.Println("worker failed on", event, "because: ", err) | ||
log.Println("Should unload", modelName) | ||
if err := storage.RemoveDir(filepath.Join(p.Downloader.ModelDir, modelName)); err != nil { | ||
log.Printf("failing to delete model directory: %v", err) | ||
} | ||
} | ||
p.completions <- modelOp | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.