Skip to content

Commit

Permalink
[POC] Multi-Model Puller (kubeflow#989)
Browse files Browse the repository at this point in the history
* [WIP] Beginning logic for multi-model puller

* change version to memory

* moving puller to cmd

* add downloader and requester logic

* add retry logic to donwload

* resolve intial comments

* resolve comments and rebase

* resolve comments

* resolve comments and add s3 logic

* have downloaded models be under a modelname

* resolve comments

* add unload logic

* added retry logic and further hardening for failures

* resolve comments and handle on-start

* resolve comments and reorganize

* fmt

* resolve comments

* inline puller

* update go mod

* move comment

* remove unnnecessary comment
  • Loading branch information
ifilonenko authored Aug 27, 2020
1 parent 0606993 commit 17b6c04
Show file tree
Hide file tree
Showing 13 changed files with 688 additions and 20 deletions.
21 changes: 21 additions & 0 deletions cmd/agent/agent.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Build the inference-agent binary
FROM golang:1.13.0 as builder

# Copy in the go src
WORKDIR /go/src/github.com/kubeflow/kfserving
COPY pkg/ pkg/
COPY cmd/ cmd/
COPY go.mod go.mod
COPY go.sum go.sum

RUN go mod download

# Build
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -o agent ./cmd/agent

# Copy the inference-agent into a thin image
FROM gcr.io/distroless/static:latest
COPY third_party/ third_party/
WORKDIR /
COPY --from=builder /go/src/github.com/kubeflow/kfserving/agent .
ENTRYPOINT ["/agent"]
56 changes: 56 additions & 0 deletions cmd/agent/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package main

import (
"flag"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/kubeflow/kfserving/pkg/agent"
"github.com/kubeflow/kfserving/pkg/agent/storage"
)

var (
configDir = flag.String("config-dir", "/mnt/configs", "directory for model config files")
modelDir = flag.String("model-dir", "/mnt/models", "directory for model files")
s3Endpoint = flag.String("s3-endpoint", "", "endpoint for s3 bucket")
s3Region = flag.String("s3-region", "us-west-2", "region for s3 bucket")
)

func main() {
flag.Parse()
downloader := agent.Downloader{
ModelDir: *modelDir,
Providers: map[storage.Protocol]storage.Provider{},
}
if *s3Endpoint != "" {
sess, err := session.NewSession(&aws.Config{
Endpoint: aws.String(*s3Endpoint),
Region: aws.String(*s3Region)},
)
if err != nil {
panic(err)
}
downloader.Providers[storage.S3] = &storage.S3Provider{
Client: s3.New(sess),
}
}

watcher := agent.Watcher{
ConfigDir: *configDir,
ModelTracker: map[string]agent.ModelWrapper{},
Puller: agent.Puller{
ChannelMap: map[string]agent.Channel{},
Downloader: downloader,
},
}

syncer := agent.Syncer{
Watcher: watcher,
}

// Doing a forced sync in the case of container failures
// and for pre-filled config maps
syncer.Start()
watcher.Start()

}
10 changes: 4 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@ require (
cloud.google.com/go v0.47.0 // indirect
contrib.go.opencensus.io/exporter/stackdriver v0.12.9-0.20191108183826-59d068f8d8ff // indirect
github.com/astaxie/beego v1.12.1
github.com/aws/aws-sdk-go v1.28.0 // indirect
github.com/aws/aws-sdk-go v1.28.0
github.com/beorn7/perks v1.0.1 // indirect
github.com/cloudevents/sdk-go v1.2.0
github.com/emicklei/go-restful v2.11.0+incompatible // indirect
github.com/fsnotify/fsnotify v1.4.9
github.com/getkin/kin-openapi v0.2.0
github.com/go-logr/logr v0.1.0
github.com/go-logr/zapr v0.1.1 // indirect
github.com/go-openapi/spec v0.19.4
github.com/gogo/protobuf v1.3.1
github.com/golang/groupcache v0.0.0-20191002201903-404acd9df4cc // indirect
github.com/golang/protobuf v1.4.1
github.com/google/go-cmp v0.4.0
github.com/google/go-cmp v0.5.0
github.com/google/go-containerregistry v0.0.0-20190910142231-b02d448a3705 // indirect
github.com/google/uuid v1.1.1
github.com/imdario/mergo v0.3.8 // indirect
Expand All @@ -38,7 +39,7 @@ require (
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 // indirect
golang.org/x/time v0.0.0-20191023065245-6d3f0bb11be5 // indirect
google.golang.org/grpc v1.27.0
google.golang.org/protobuf v1.24.0
google.golang.org/protobuf v1.25.0
istio.io/api v0.0.0-20191115173247-e1a1952e5b81
istio.io/client-go v0.0.0-20191120150049-26c62a04cdbc
istio.io/gogo-genproto v0.0.0-20191029161641-f7d19ec0141d // indirect
Expand All @@ -51,7 +52,4 @@ require (
knative.dev/pkg v0.0.0-20191217184203-cf220a867b3d
knative.dev/serving v0.11.0
sigs.k8s.io/controller-runtime v0.4.0
sigs.k8s.io/yaml v1.2.0 // indirect
)

//replace gopkg.in/fsnotify.v1 v1.4.7 => github.com/fsnotify/fsnotify v1.4.7
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/getkin/kin-openapi v0.2.0 h1:PbHHtYZpjKwZtGlIyELgA2DploRrsaXztoNNx9HjwNY=
github.com/getkin/kin-openapi v0.2.0/go.mod h1:V1z9xl9oF5Wt7v32ne4FmiF1alpS4dM6mNzoywPOXlk=
github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
Expand Down Expand Up @@ -244,6 +245,7 @@ github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-containerregistry v0.0.0-20190910142231-b02d448a3705 h1:rsBH4vQ2gLNUKf2+82LNQ45AsYnH12Q5ZnHiZXx9LZw=
github.com/google/go-containerregistry v0.0.0-20190910142231-b02d448a3705/go.mod h1:yZAFP63pRshzrEYLXLGPmUt0Ay+2zdjmMN1loCnRLUk=
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
Expand Down Expand Up @@ -591,6 +593,7 @@ golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190912141932-bc967efca4b8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down Expand Up @@ -696,6 +699,7 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.24.0 h1:UhZDfRO8JRQru4/+LlLE0BRKGF8L+PICnvYZmx/fEGA=
google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4=
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
90 changes: 90 additions & 0 deletions pkg/agent/downloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package agent

import (
"encoding/hex"
"fmt"
"github.com/kubeflow/kfserving/pkg/agent/storage"
"log"
"path/filepath"
"regexp"
"strings"
)

type Downloader struct {
ModelDir string
Providers map[storage.Protocol]storage.Provider
}

var SupportedProtocols = []storage.Protocol{storage.S3}

func (d *Downloader) DownloadModel(event EventWrapper) error {
modelSpec := event.ModelSpec
modelName := event.ModelName
if modelSpec != nil {
modelUri := modelSpec.StorageURI
hashModelUri := hash(modelUri)
hashFramework := hash(modelSpec.Framework)
hashMemory := hash(modelSpec.Memory.String())
log.Println("Processing:", modelUri, "=", hashModelUri, hashFramework, hashMemory)
successFile := filepath.Join(d.ModelDir, modelName,
fmt.Sprintf("SUCCESS.%s.%s.%s", hashModelUri, hashFramework, hashMemory))
// Download if the event there is a success file and the event is one which we wish to Download
if !storage.FileExists(successFile) && event.ShouldDownload {
// TODO: Handle retry logic
if err := d.download(modelName, modelUri); err != nil {
return fmt.Errorf("download error: %v", err)
}
file, createErr := storage.Create(successFile)
if createErr != nil {
return fmt.Errorf("create file error: %v", createErr)
}
defer file.Close()
} else if !event.ShouldDownload {
log.Println("Model", modelName, "does not need to be re-downloaded")
} else {
log.Println("Model", modelSpec.StorageURI, "exists already")
}
}
return nil
}

func (d *Downloader) download(modelName string, storageUri string) error {
log.Println("Downloading: ", storageUri)
protocol, err := extractProtocol(storageUri)
if err != nil {
return fmt.Errorf("unsupported protocol: %v", err)
}
provider, ok := d.Providers[protocol]
if !ok {
return fmt.Errorf("protocol manager for %s is not initialized", protocol)
}
if err := provider.Download(d.ModelDir, modelName, storageUri); err != nil {
return fmt.Errorf("failure on download: %v", err)
}

return nil
}

func hash(s string) string {
src := []byte(s)
dst := make([]byte, hex.EncodedLen(len(src)))
hex.Encode(dst, src)
return string(dst)
}

func extractProtocol(storageURI string) (storage.Protocol, error) {
if storageURI == "" {
return "", fmt.Errorf("there is no storageUri supplied")
}

if !regexp.MustCompile("\\w+?://").MatchString(storageURI) {
return "", fmt.Errorf("there is no protocol specificed for the storageUri")
}

for _, prefix := range SupportedProtocols {
if strings.HasPrefix(storageURI, string(prefix)) {
return prefix, nil
}
}
return "", fmt.Errorf("protocol not supported for storageUri")
}
73 changes: 73 additions & 0 deletions pkg/agent/puller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package agent

import (
"fmt"
"github.com/kubeflow/kfserving/pkg/agent/storage"
"log"
"path/filepath"
)

type Puller struct {
ChannelMap map[string]Channel
Downloader Downloader
}

type Channel struct {
EventChannel chan EventWrapper
}

func (p *Puller) AddModel(modelName string) Channel {
// TODO: Figure out the appropriate buffer-size for this
// TODO: Check if event Channel exists
eventChannel := make(chan EventWrapper, 20)
channel := Channel{
EventChannel: eventChannel,
}
go p.modelProcessor(modelName, channel.EventChannel)
p.ChannelMap[modelName] = channel
return p.ChannelMap[modelName]
}

func (p *Puller) RemoveModel(modelName string) error {
channel, ok := p.ChannelMap[modelName]
if ok {
close(channel.EventChannel)
delete(p.ChannelMap, modelName)
}
if err := storage.RemoveDir(filepath.Join(p.Downloader.ModelDir, modelName)); err != nil {
return fmt.Errorf("failing to delete model directory: %v", err)
}
return nil
}

func (p *Puller) modelProcessor(modelName string, events chan EventWrapper) {
log.Println("worker for", modelName, "is initialized")
// TODO: Instead of going through each event, one-by-one, we need to drain and combine
// this is important for handling Load --> Unload requests sent in tandem
// Load --> Unload = 0 (cancel first load)
// Load --> Unload --> Load = 1 Load (cancel second load?)
for event := range events {
log.Println("worker", modelName, "started job", event)
switch event.LoadState {
case ShouldLoad:
log.Println("Should download", event.ModelSpec.StorageURI)
err := p.Downloader.DownloadModel(event)
if err != nil {
log.Println("worker failed on", event, "because: ", err)
} else {
// If there is an error, we will NOT send a request. As such, to know about errors, you will
// need to call the error endpoint of the puller
// TODO: Do request logic
log.Println("Now doing a request on", event)
}
case ShouldUnload:
// TODO: Do request logic
log.Println("Now doing a request on", event)
// If there is an error, we will NOT do a delete... that could be problematic
log.Println("Should unload", event.ModelName)
if err := p.RemoveModel(event.ModelName); err != nil {
log.Println("worker failed on", event, "because: ", err)
}
}
}
}
15 changes: 15 additions & 0 deletions pkg/agent/storage/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package storage

type Provider interface {
Download(modelDir string, modelName string, storageUri string) error
}

type Protocol string

const (
S3 Protocol = "s3://"
//GCS Protocol = "gs://"
//PVC Protocol = "pvc://"
//File Protocol = "file://"
//HTTPS Protocol = "https://"
)
Loading

0 comments on commit 17b6c04

Please sign in to comment.