forked from kubeflow/pipelines
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[POC] Multi-Model Puller (kubeflow#989)
* [WIP] Beginning logic for multi-model puller * change version to memory * moving puller to cmd * add downloader and requester logic * add retry logic to download * resolve initial comments * resolve comments and rebase * resolve comments * resolve comments and add s3 logic * have downloaded models be under a modelname * resolve comments * add unload logic * added retry logic and further hardening for failures * resolve comments and handle on-start * resolve comments and reorganize * fmt * resolve comments * inline puller * update go mod * move comment * remove unnecessary comment
- Loading branch information
1 parent
0606993
commit 17b6c04
Showing
13 changed files
with
688 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Build the inference-agent binary
FROM golang:1.13.0 as builder

WORKDIR /go/src/github.com/kubeflow/kfserving

# Copy only the module files first and download dependencies, so this layer
# stays cached until go.mod or go.sum change (source edits no longer force a
# fresh `go mod download`).
COPY go.mod go.mod
COPY go.sum go.sum
RUN go mod download

# Copy in the go src
COPY pkg/ pkg/
COPY cmd/ cmd/

# Build
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -o agent ./cmd/agent

# Copy the inference-agent into a thin image
FROM gcr.io/distroless/static:latest
COPY third_party/ third_party/
WORKDIR /
COPY --from=builder /go/src/github.com/kubeflow/kfserving/agent .
ENTRYPOINT ["/agent"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
package main | ||
|
||
import ( | ||
"flag" | ||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/session" | ||
"github.com/aws/aws-sdk-go/service/s3" | ||
"github.com/kubeflow/kfserving/pkg/agent" | ||
"github.com/kubeflow/kfserving/pkg/agent/storage" | ||
) | ||
|
||
// Command-line flags read by main. | ||
// An empty s3-endpoint (the default) means main does not register an S3 provider. | ||
var ( | ||
configDir = flag.String("config-dir", "/mnt/configs", "directory for model config files") | ||
modelDir = flag.String("model-dir", "/mnt/models", "directory for model files") | ||
s3Endpoint = flag.String("s3-endpoint", "", "endpoint for s3 bucket") | ||
s3Region = flag.String("s3-region", "us-west-2", "region for s3 bucket") | ||
) | ||
|
||
func main() { | ||
flag.Parse() | ||
downloader := agent.Downloader{ | ||
ModelDir: *modelDir, | ||
Providers: map[storage.Protocol]storage.Provider{}, | ||
} | ||
if *s3Endpoint != "" { | ||
sess, err := session.NewSession(&aws.Config{ | ||
Endpoint: aws.String(*s3Endpoint), | ||
Region: aws.String(*s3Region)}, | ||
) | ||
if err != nil { | ||
panic(err) | ||
} | ||
downloader.Providers[storage.S3] = &storage.S3Provider{ | ||
Client: s3.New(sess), | ||
} | ||
} | ||
|
||
watcher := agent.Watcher{ | ||
ConfigDir: *configDir, | ||
ModelTracker: map[string]agent.ModelWrapper{}, | ||
Puller: agent.Puller{ | ||
ChannelMap: map[string]agent.Channel{}, | ||
Downloader: downloader, | ||
}, | ||
} | ||
|
||
syncer := agent.Syncer{ | ||
Watcher: watcher, | ||
} | ||
|
||
// Doing a forced sync in the case of container failures | ||
// and for pre-filled config maps | ||
syncer.Start() | ||
watcher.Start() | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package agent | ||
|
||
import ( | ||
"encoding/hex" | ||
"fmt" | ||
"github.com/kubeflow/kfserving/pkg/agent/storage" | ||
"log" | ||
"path/filepath" | ||
"regexp" | ||
"strings" | ||
) | ||
|
||
// Downloader downloads models through the storage provider registered for the | ||
// protocol of the model's storage URI. | ||
type Downloader struct { | ||
// ModelDir is the root directory handed to providers; success marker files are | ||
// written under <ModelDir>/<modelName>/. | ||
ModelDir string | ||
// Providers maps a storage protocol to the provider used to download from it. | ||
Providers map[storage.Protocol]storage.Provider | ||
} | ||
|
||
// SupportedProtocols lists the protocol prefixes extractProtocol accepts at the | ||
// start of a storage URI. | ||
var SupportedProtocols = []storage.Protocol{storage.S3} | ||
|
||
func (d *Downloader) DownloadModel(event EventWrapper) error { | ||
modelSpec := event.ModelSpec | ||
modelName := event.ModelName | ||
if modelSpec != nil { | ||
modelUri := modelSpec.StorageURI | ||
hashModelUri := hash(modelUri) | ||
hashFramework := hash(modelSpec.Framework) | ||
hashMemory := hash(modelSpec.Memory.String()) | ||
log.Println("Processing:", modelUri, "=", hashModelUri, hashFramework, hashMemory) | ||
successFile := filepath.Join(d.ModelDir, modelName, | ||
fmt.Sprintf("SUCCESS.%s.%s.%s", hashModelUri, hashFramework, hashMemory)) | ||
// Download if the event there is a success file and the event is one which we wish to Download | ||
if !storage.FileExists(successFile) && event.ShouldDownload { | ||
// TODO: Handle retry logic | ||
if err := d.download(modelName, modelUri); err != nil { | ||
return fmt.Errorf("download error: %v", err) | ||
} | ||
file, createErr := storage.Create(successFile) | ||
if createErr != nil { | ||
return fmt.Errorf("create file error: %v", createErr) | ||
} | ||
defer file.Close() | ||
} else if !event.ShouldDownload { | ||
log.Println("Model", modelName, "does not need to be re-downloaded") | ||
} else { | ||
log.Println("Model", modelSpec.StorageURI, "exists already") | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func (d *Downloader) download(modelName string, storageUri string) error { | ||
log.Println("Downloading: ", storageUri) | ||
protocol, err := extractProtocol(storageUri) | ||
if err != nil { | ||
return fmt.Errorf("unsupported protocol: %v", err) | ||
} | ||
provider, ok := d.Providers[protocol] | ||
if !ok { | ||
return fmt.Errorf("protocol manager for %s is not initialized", protocol) | ||
} | ||
if err := provider.Download(d.ModelDir, modelName, storageUri); err != nil { | ||
return fmt.Errorf("failure on download: %v", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// hash returns the lowercase hexadecimal encoding of the bytes of s.
// Despite its name it is a reversible encoding, not a digest: the output is
// twice as long as the input.
func hash(s string) string {
	return hex.EncodeToString([]byte(s))
}
|
||
func extractProtocol(storageURI string) (storage.Protocol, error) { | ||
if storageURI == "" { | ||
return "", fmt.Errorf("there is no storageUri supplied") | ||
} | ||
|
||
if !regexp.MustCompile("\\w+?://").MatchString(storageURI) { | ||
return "", fmt.Errorf("there is no protocol specificed for the storageUri") | ||
} | ||
|
||
for _, prefix := range SupportedProtocols { | ||
if strings.HasPrefix(storageURI, string(prefix)) { | ||
return prefix, nil | ||
} | ||
} | ||
return "", fmt.Errorf("protocol not supported for storageUri") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
package agent | ||
|
||
import ( | ||
"fmt" | ||
"github.com/kubeflow/kfserving/pkg/agent/storage" | ||
"log" | ||
"path/filepath" | ||
) | ||
|
||
// Puller runs one worker goroutine per model and feeds it load/unload events. | ||
type Puller struct { | ||
// ChannelMap maps a model name to the channel consumed by that model's worker. | ||
ChannelMap map[string]Channel | ||
// Downloader is used by the workers to download models. | ||
Downloader Downloader | ||
} | ||
|
||
// Channel wraps the event channel consumed by a single model's worker. | ||
type Channel struct { | ||
// EventChannel carries the events for one model; AddModel creates it buffered. | ||
EventChannel chan EventWrapper | ||
} | ||
|
||
func (p *Puller) AddModel(modelName string) Channel { | ||
// TODO: Figure out the appropriate buffer-size for this | ||
// TODO: Check if event Channel exists | ||
eventChannel := make(chan EventWrapper, 20) | ||
channel := Channel{ | ||
EventChannel: eventChannel, | ||
} | ||
go p.modelProcessor(modelName, channel.EventChannel) | ||
p.ChannelMap[modelName] = channel | ||
return p.ChannelMap[modelName] | ||
} | ||
|
||
func (p *Puller) RemoveModel(modelName string) error { | ||
channel, ok := p.ChannelMap[modelName] | ||
if ok { | ||
close(channel.EventChannel) | ||
delete(p.ChannelMap, modelName) | ||
} | ||
if err := storage.RemoveDir(filepath.Join(p.Downloader.ModelDir, modelName)); err != nil { | ||
return fmt.Errorf("failing to delete model directory: %v", err) | ||
} | ||
return nil | ||
} | ||
|
||
func (p *Puller) modelProcessor(modelName string, events chan EventWrapper) { | ||
log.Println("worker for", modelName, "is initialized") | ||
// TODO: Instead of going through each event, one-by-one, we need to drain and combine | ||
// this is important for handling Load --> Unload requests sent in tandem | ||
// Load --> Unload = 0 (cancel first load) | ||
// Load --> Unload --> Load = 1 Load (cancel second load?) | ||
for event := range events { | ||
log.Println("worker", modelName, "started job", event) | ||
switch event.LoadState { | ||
case ShouldLoad: | ||
log.Println("Should download", event.ModelSpec.StorageURI) | ||
err := p.Downloader.DownloadModel(event) | ||
if err != nil { | ||
log.Println("worker failed on", event, "because: ", err) | ||
} else { | ||
// If there is an error, we will NOT send a request. As such, to know about errors, you will | ||
// need to call the error endpoint of the puller | ||
// TODO: Do request logic | ||
log.Println("Now doing a request on", event) | ||
} | ||
case ShouldUnload: | ||
// TODO: Do request logic | ||
log.Println("Now doing a request on", event) | ||
// If there is an error, we will NOT do a delete... that could be problematic | ||
log.Println("Should unload", event.ModelName) | ||
if err := p.RemoveModel(event.ModelName); err != nil { | ||
log.Println("worker failed on", event, "because: ", err) | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package storage | ||
|
||
// Provider downloads a model from one storage backend. | ||
type Provider interface { | ||
// Download fetches the model at storageUri for modelName; modelDir is the root | ||
// model directory. NOTE(review): presumably files land under | ||
// modelDir/modelName — confirm against the implementations. | ||
Download(modelDir string, modelName string, storageUri string) error | ||
} | ||
|
||
// Protocol is a storage URI prefix, including the "://" separator. | ||
type Protocol string | ||
|
||
// Known protocols. Only S3 is declared; the others are commented out. | ||
const ( | ||
S3 Protocol = "s3://" | ||
//GCS Protocol = "gs://" | ||
//PVC Protocol = "pvc://" | ||
//File Protocol = "file://" | ||
//HTTPS Protocol = "https://" | ||
) |
Oops, something went wrong.