From 17b6c04bac13e99ee1830a320bc93e388d3a0237 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Thu, 27 Aug 2020 14:05:52 -0400 Subject: [PATCH] [POC] Multi-Model Puller (#989) * [WIP] Beginning logic for multi-model puller * change version to memory * moving puller to cmd * add downloader and requester logic * add retry logic to download * resolve initial comments * resolve comments and rebase * resolve comments * resolve comments and add s3 logic * have downloaded models be under a modelname * resolve comments * add unload logic * added retry logic and further hardening for failures * resolve comments and handle on-start * resolve comments and reorganize * fmt * resolve comments * inline puller * update go mod * move comment * remove unnecessary comment --- cmd/agent/agent.Dockerfile | 21 ++++ cmd/agent/main.go | 56 +++++++++++ go.mod | 10 +- go.sum | 4 + pkg/agent/downloader.go | 90 +++++++++++++++++ pkg/agent/puller.go | 73 ++++++++++++++ pkg/agent/storage/provider.go | 15 +++ pkg/agent/storage/s3.go | 101 +++++++++++++++++++ pkg/agent/storage/utils.go | 42 ++++++++ pkg/agent/syncer.go | 109 ++++++++++++++++++++ pkg/agent/watcher.go | 161 ++++++++++++++++++++++++++++++ pkg/modelconfig/configmap.go | 4 +- pkg/modelconfig/configmap_test.go | 22 ++-- 13 files changed, 688 insertions(+), 20 deletions(-) create mode 100644 cmd/agent/agent.Dockerfile create mode 100644 cmd/agent/main.go create mode 100644 pkg/agent/downloader.go create mode 100644 pkg/agent/puller.go create mode 100644 pkg/agent/storage/provider.go create mode 100644 pkg/agent/storage/s3.go create mode 100644 pkg/agent/storage/utils.go create mode 100644 pkg/agent/syncer.go create mode 100644 pkg/agent/watcher.go diff --git a/cmd/agent/agent.Dockerfile b/cmd/agent/agent.Dockerfile new file mode 100644 index 00000000000..4474c6795d5 --- /dev/null +++ b/cmd/agent/agent.Dockerfile @@ -0,0 +1,21 @@ +# Build the inference-agent binary +FROM golang:1.13.0 as builder + +# Copy in the go src +WORKDIR
/go/src/github.com/kubeflow/kfserving +COPY pkg/ pkg/ +COPY cmd/ cmd/ +COPY go.mod go.mod +COPY go.sum go.sum + +RUN go mod download + +# Build +RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -o agent ./cmd/agent + +# Copy the inference-agent into a thin image +FROM gcr.io/distroless/static:latest +COPY third_party/ third_party/ +WORKDIR / +COPY --from=builder /go/src/github.com/kubeflow/kfserving/agent . +ENTRYPOINT ["/agent"] diff --git a/cmd/agent/main.go b/cmd/agent/main.go new file mode 100644 index 00000000000..963d5457448 --- /dev/null +++ b/cmd/agent/main.go @@ -0,0 +1,56 @@ +package main + +import ( + "flag" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/kubeflow/kfserving/pkg/agent" + "github.com/kubeflow/kfserving/pkg/agent/storage" +) + +var ( + configDir = flag.String("config-dir", "/mnt/configs", "directory for model config files") + modelDir = flag.String("model-dir", "/mnt/models", "directory for model files") + s3Endpoint = flag.String("s3-endpoint", "", "endpoint for s3 bucket") + s3Region = flag.String("s3-region", "us-west-2", "region for s3 bucket") +) + +func main() { + flag.Parse() + downloader := agent.Downloader{ + ModelDir: *modelDir, + Providers: map[storage.Protocol]storage.Provider{}, + } + if *s3Endpoint != "" { + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(*s3Endpoint), + Region: aws.String(*s3Region)}, + ) + if err != nil { + panic(err) + } + downloader.Providers[storage.S3] = &storage.S3Provider{ + Client: s3.New(sess), + } + } + + watcher := agent.Watcher{ + ConfigDir: *configDir, + ModelTracker: map[string]agent.ModelWrapper{}, + Puller: agent.Puller{ + ChannelMap: map[string]agent.Channel{}, + Downloader: downloader, + }, + } + + syncer := agent.Syncer{ + Watcher: watcher, + } + + // Doing a forced sync in the case of container failures + // and for pre-filled config maps + syncer.Start() + watcher.Start() + +} diff 
--git a/go.mod b/go.mod index b8ccb4d61b5..818b9256ab6 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,11 @@ require ( cloud.google.com/go v0.47.0 // indirect contrib.go.opencensus.io/exporter/stackdriver v0.12.9-0.20191108183826-59d068f8d8ff // indirect github.com/astaxie/beego v1.12.1 - github.com/aws/aws-sdk-go v1.28.0 // indirect + github.com/aws/aws-sdk-go v1.28.0 github.com/beorn7/perks v1.0.1 // indirect github.com/cloudevents/sdk-go v1.2.0 github.com/emicklei/go-restful v2.11.0+incompatible // indirect + github.com/fsnotify/fsnotify v1.4.9 github.com/getkin/kin-openapi v0.2.0 github.com/go-logr/logr v0.1.0 github.com/go-logr/zapr v0.1.1 // indirect @@ -17,7 +18,7 @@ require ( github.com/gogo/protobuf v1.3.1 github.com/golang/groupcache v0.0.0-20191002201903-404acd9df4cc // indirect github.com/golang/protobuf v1.4.1 - github.com/google/go-cmp v0.4.0 + github.com/google/go-cmp v0.5.0 github.com/google/go-containerregistry v0.0.0-20190910142231-b02d448a3705 // indirect github.com/google/uuid v1.1.1 github.com/imdario/mergo v0.3.8 // indirect @@ -38,7 +39,7 @@ require ( golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 // indirect golang.org/x/time v0.0.0-20191023065245-6d3f0bb11be5 // indirect google.golang.org/grpc v1.27.0 - google.golang.org/protobuf v1.24.0 + google.golang.org/protobuf v1.25.0 istio.io/api v0.0.0-20191115173247-e1a1952e5b81 istio.io/client-go v0.0.0-20191120150049-26c62a04cdbc istio.io/gogo-genproto v0.0.0-20191029161641-f7d19ec0141d // indirect @@ -51,7 +52,4 @@ require ( knative.dev/pkg v0.0.0-20191217184203-cf220a867b3d knative.dev/serving v0.11.0 sigs.k8s.io/controller-runtime v0.4.0 - sigs.k8s.io/yaml v1.2.0 // indirect ) - -//replace gopkg.in/fsnotify.v1 v1.4.7 => github.com/fsnotify/fsnotify v1.4.7 diff --git a/go.sum b/go.sum index bd54fb53c1d..a82c2b1db6b 100644 --- a/go.sum +++ b/go.sum @@ -133,6 +133,7 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fortytw2/leaktest v1.3.0/go.mod 
h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/getkin/kin-openapi v0.2.0 h1:PbHHtYZpjKwZtGlIyELgA2DploRrsaXztoNNx9HjwNY= github.com/getkin/kin-openapi v0.2.0/go.mod h1:V1z9xl9oF5Wt7v32ne4FmiF1alpS4dM6mNzoywPOXlk= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= @@ -244,6 +245,7 @@ github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-containerregistry v0.0.0-20190910142231-b02d448a3705 h1:rsBH4vQ2gLNUKf2+82LNQ45AsYnH12Q5ZnHiZXx9LZw= github.com/google/go-containerregistry v0.0.0-20190910142231-b02d448a3705/go.mod h1:yZAFP63pRshzrEYLXLGPmUt0Ay+2zdjmMN1loCnRLUk= github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= @@ -591,6 +593,7 @@ golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190912141932-bc967efca4b8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 
h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -696,6 +699,7 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0 h1:UhZDfRO8JRQru4/+LlLE0BRKGF8L+PICnvYZmx/fEGA= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/agent/downloader.go b/pkg/agent/downloader.go new file mode 100644 index 00000000000..7f860079f2f --- /dev/null +++ b/pkg/agent/downloader.go @@ -0,0 +1,90 @@ +package agent + +import ( + "encoding/hex" + "fmt" + "github.com/kubeflow/kfserving/pkg/agent/storage" + "log" + "path/filepath" + "regexp" + "strings" +) + +type Downloader struct { + ModelDir string + Providers map[storage.Protocol]storage.Provider +} + +var SupportedProtocols = []storage.Protocol{storage.S3} + +func (d *Downloader) DownloadModel(event EventWrapper) error { + modelSpec := event.ModelSpec + modelName := event.ModelName + if modelSpec != nil { + modelUri := modelSpec.StorageURI + hashModelUri := hash(modelUri) + hashFramework := hash(modelSpec.Framework) + hashMemory := hash(modelSpec.Memory.String()) + log.Println("Processing:", modelUri, "=", hashModelUri, hashFramework, hashMemory) + successFile := filepath.Join(d.ModelDir, modelName, 
+ fmt.Sprintf("SUCCESS.%s.%s.%s", hashModelUri, hashFramework, hashMemory)) + // Download if the event there is a success file and the event is one which we wish to Download + if !storage.FileExists(successFile) && event.ShouldDownload { + // TODO: Handle retry logic + if err := d.download(modelName, modelUri); err != nil { + return fmt.Errorf("download error: %v", err) + } + file, createErr := storage.Create(successFile) + if createErr != nil { + return fmt.Errorf("create file error: %v", createErr) + } + defer file.Close() + } else if !event.ShouldDownload { + log.Println("Model", modelName, "does not need to be re-downloaded") + } else { + log.Println("Model", modelSpec.StorageURI, "exists already") + } + } + return nil +} + +func (d *Downloader) download(modelName string, storageUri string) error { + log.Println("Downloading: ", storageUri) + protocol, err := extractProtocol(storageUri) + if err != nil { + return fmt.Errorf("unsupported protocol: %v", err) + } + provider, ok := d.Providers[protocol] + if !ok { + return fmt.Errorf("protocol manager for %s is not initialized", protocol) + } + if err := provider.Download(d.ModelDir, modelName, storageUri); err != nil { + return fmt.Errorf("failure on download: %v", err) + } + + return nil +} + +func hash(s string) string { + src := []byte(s) + dst := make([]byte, hex.EncodedLen(len(src))) + hex.Encode(dst, src) + return string(dst) +} + +func extractProtocol(storageURI string) (storage.Protocol, error) { + if storageURI == "" { + return "", fmt.Errorf("there is no storageUri supplied") + } + + if !regexp.MustCompile("\\w+?://").MatchString(storageURI) { + return "", fmt.Errorf("there is no protocol specificed for the storageUri") + } + + for _, prefix := range SupportedProtocols { + if strings.HasPrefix(storageURI, string(prefix)) { + return prefix, nil + } + } + return "", fmt.Errorf("protocol not supported for storageUri") +} diff --git a/pkg/agent/puller.go b/pkg/agent/puller.go new file mode 100644 index 
00000000000..95411c67267 --- /dev/null +++ b/pkg/agent/puller.go @@ -0,0 +1,73 @@ +package agent + +import ( + "fmt" + "github.com/kubeflow/kfserving/pkg/agent/storage" + "log" + "path/filepath" +) + +type Puller struct { + ChannelMap map[string]Channel + Downloader Downloader +} + +type Channel struct { + EventChannel chan EventWrapper +} + +func (p *Puller) AddModel(modelName string) Channel { + // TODO: Figure out the appropriate buffer-size for this + // TODO: Check if event Channel exists + eventChannel := make(chan EventWrapper, 20) + channel := Channel{ + EventChannel: eventChannel, + } + go p.modelProcessor(modelName, channel.EventChannel) + p.ChannelMap[modelName] = channel + return p.ChannelMap[modelName] +} + +func (p *Puller) RemoveModel(modelName string) error { + channel, ok := p.ChannelMap[modelName] + if ok { + close(channel.EventChannel) + delete(p.ChannelMap, modelName) + } + if err := storage.RemoveDir(filepath.Join(p.Downloader.ModelDir, modelName)); err != nil { + return fmt.Errorf("failing to delete model directory: %v", err) + } + return nil +} + +func (p *Puller) modelProcessor(modelName string, events chan EventWrapper) { + log.Println("worker for", modelName, "is initialized") + // TODO: Instead of going through each event, one-by-one, we need to drain and combine + // this is important for handling Load --> Unload requests sent in tandem + // Load --> Unload = 0 (cancel first load) + // Load --> Unload --> Load = 1 Load (cancel second load?) + for event := range events { + log.Println("worker", modelName, "started job", event) + switch event.LoadState { + case ShouldLoad: + log.Println("Should download", event.ModelSpec.StorageURI) + err := p.Downloader.DownloadModel(event) + if err != nil { + log.Println("worker failed on", event, "because: ", err) + } else { + // If there is an error, we will NOT send a request. 
As such, to know about errors, you will + // need to call the error endpoint of the puller + // TODO: Do request logic + log.Println("Now doing a request on", event) + } + case ShouldUnload: + // TODO: Do request logic + log.Println("Now doing a request on", event) + // If there is an error, we will NOT do a delete... that could be problematic + log.Println("Should unload", event.ModelName) + if err := p.RemoveModel(event.ModelName); err != nil { + log.Println("worker failed on", event, "because: ", err) + } + } + } +} diff --git a/pkg/agent/storage/provider.go b/pkg/agent/storage/provider.go new file mode 100644 index 00000000000..e5f534357b0 --- /dev/null +++ b/pkg/agent/storage/provider.go @@ -0,0 +1,15 @@ +package storage + +type Provider interface { + Download(modelDir string, modelName string, storageUri string) error +} + +type Protocol string + +const ( + S3 Protocol = "s3://" + //GCS Protocol = "gs://" + //PVC Protocol = "pvc://" + //File Protocol = "file://" + //HTTPS Protocol = "https://" +) diff --git a/pkg/agent/storage/s3.go b/pkg/agent/storage/s3.go new file mode 100644 index 00000000000..1c37cd6e903 --- /dev/null +++ b/pkg/agent/storage/s3.go @@ -0,0 +1,101 @@ +package storage + +import ( + "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "log" + "os" + "path/filepath" + "strings" +) + +type S3Provider struct { + Client *s3.S3 +} + +func (m *S3Provider) Download(modelDir string, modelName string, storageUri string) error { + s3Uri := strings.TrimPrefix(storageUri, string(S3)) + path := strings.Split(s3Uri, "/") + s3ObjectDownloader := &S3ObjectDownloader{ + StorageUri: storageUri, + ModelDir: modelDir, + ModelName: modelName, + Bucket: path[0], + Item: path[1], + } + objects, err := s3ObjectDownloader.GetAllObjects(m.Client) + if err != nil { + return fmt.Errorf("unable to get batch objects %v", err) + } + if err := s3ObjectDownloader.Download(m.Client, objects); 
err != nil { + return fmt.Errorf("unable to get download objects %v", err) + } + return nil +} + +var _ Provider = (*S3Provider)(nil) + +type S3ObjectDownloader struct { + StorageUri string + ModelDir string + ModelName string + Bucket string + Item string +} + +func (s *S3ObjectDownloader) GetAllObjects(s3Svc *s3.S3) ([]s3manager.BatchDownloadObject, error) { + resp, err := s3Svc.ListObjects(&s3.ListObjectsInput{ + Bucket: aws.String(s.Bucket), + Prefix: aws.String(s.Item), + }) + if err != nil { + return nil, err + } + results := make([]s3manager.BatchDownloadObject, 0) + + if len(resp.Contents) == 0 { + return nil, fmt.Errorf("%s has no objects or does not exist", s.StorageUri) + } + + for _, object := range resp.Contents { + fileName := filepath.Join(s.ModelDir, s.ModelName, *object.Key) + if FileExists(fileName) { + // File got corrupted or is mid-download :( + // TODO: Figure out if we can maybe continue? + log.Println("Deleting", fileName) + if err := os.Remove(fileName); err != nil { + return nil, fmt.Errorf("file is unable to be deleted: %v", err) + } + } + file, err := Create(fileName) + if err != nil { + return nil, fmt.Errorf("file is already created: %v", err) + } + object := s3manager.BatchDownloadObject{ + Object: &s3.GetObjectInput{ + Key: aws.String(*object.Key), + Bucket: aws.String(s.Bucket), + }, + Writer: file, + After: func() error { + defer file.Close() + return nil + }, + } + results = append(results, object) + } + return results, nil +} + +func (s *S3ObjectDownloader) Download(s3Svc *s3.S3, objects []s3manager.BatchDownloadObject) error { + iter := &s3manager.DownloadObjectsIterator{Objects: objects} + downloader := s3manager.NewDownloaderWithClient(s3Svc, func(d *s3manager.Downloader) { + // TODO: Consider to do overrides + }) + if err := downloader.DownloadWithIterator(aws.BackgroundContext(), iter); err != nil { + return err + } + return nil +} diff --git a/pkg/agent/storage/utils.go b/pkg/agent/storage/utils.go new file mode 100644 
// FileExists reports whether filename exists and is a regular file (not a
// directory). Any stat failure — not-exist, permission denied, etc. — is
// reported as "does not exist".
func FileExists(filename string) bool {
	info, err := os.Stat(filename)
	if err != nil {
		// The previous version only checked os.IsNotExist and then
		// dereferenced info, which is nil on ANY stat error — a panic on
		// e.g. permission-denied.
		return false
	}
	return !info.IsDir()
}

// Create makes every parent directory of fileName (mode 0770) and then
// creates (or truncates) the file itself.
func Create(fileName string) (*os.File, error) {
	if err := os.MkdirAll(filepath.Dir(fileName), 0770); err != nil {
		return nil, err
	}
	return os.Create(fileName)
}

// RemoveDir removes every entry inside dir and then dir itself. Unlike
// os.RemoveAll, it returns an error when dir does not exist or cannot be
// opened.
func RemoveDir(dir string) error {
	d, err := os.Open(dir)
	if err != nil {
		return err
	}
	defer d.Close()
	names, err := d.Readdirnames(-1)
	if err != nil {
		return err
	}
	for _, name := range names {
		err = os.RemoveAll(filepath.Join(dir, name))
		if err != nil {
			return err
		}
	}
	// Remove empty dir
	if err := os.Remove(dir); err != nil {
		return fmt.Errorf("dir is unable to be deleted: %v", err)
	}
	return nil
}
// unhash reverses hash (downloader.go): it hex-decodes s back to the
// original string. The decode error is now returned instead of being
// silently swallowed (the previous version returned `"", nil` on failure),
// so a tampered or corrupted SUCCESS file name surfaces to the caller.
func unhash(s string) (string, error) {
	decoded, err := hex.DecodeString(s)
	if err != nil {
		return "", err
	}
	return string(decoded), nil
}
+1,161 @@ +package agent + +import ( + "encoding/json" + "github.com/fsnotify/fsnotify" + "github.com/google/go-cmp/cmp" + "github.com/kubeflow/kfserving/pkg/apis/serving/v1beta1" + "github.com/kubeflow/kfserving/pkg/constants" + "github.com/kubeflow/kfserving/pkg/modelconfig" + "io/ioutil" + "log" + "path/filepath" + "time" +) + +type Watcher struct { + ConfigDir string + ModelTracker map[string]ModelWrapper + Puller Puller +} + +type LoadState string + +const ( + // State Related + ShouldLoad LoadState = "Load" + ShouldUnload LoadState = "Unload" +) + +type EventWrapper struct { + ModelName string + ModelSpec *v1beta1.ModelSpec + LoadState LoadState + ShouldDownload bool +} + +type ModelWrapper struct { + ModelSpec *v1beta1.ModelSpec + Time time.Time + Stale bool + Redownload bool +} + +func (w *Watcher) Start() { + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Fatal(err) + } + defer watcher.Close() + + done := make(chan bool) + go func() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + isCreate := event.Op&fsnotify.Create != 0 + eventPath := filepath.Clean(event.Name) + isDataDir := filepath.Base(eventPath) == "..data" + // TODO: Should we use atomic integer or timestamp?? 
+ if isDataDir && isCreate { + symlink, _ := filepath.EvalSymlinks(eventPath) + file, err := ioutil.ReadFile(filepath.Join(symlink, constants.ModelConfigFileName)) + if err != nil { + log.Println("Error in reading file", err) + } else { + modelConfigs := make(modelconfig.ModelConfigs, 0) + err = json.Unmarshal([]byte(file), &modelConfigs) + if err != nil { + log.Println("unable to marshall for", event, "with error", err) + } + w.ParseConfig(modelConfigs) + } + } + case err, ok := <-watcher.Errors: + if ok { // 'Errors' channel is not closed + log.Println("watcher error", err) + } + if !ok { + return + } + } + } + }() + err = watcher.Add(w.ConfigDir) + if err != nil { + log.Fatal(err) + } + log.Println("Watching", w.ConfigDir) + <-done +} + +func (w *Watcher) ParseConfig(modelConfigs modelconfig.ModelConfigs) { + timeNow := time.Now() + for _, modelConfig := range modelConfigs { + modelName := modelConfig.Name + modelSpec := modelConfig.Spec + log.Println("Name:", modelName, "Spec:", modelSpec) + oldModel, ok := w.ModelTracker[modelName] + if !ok { + w.ModelTracker[modelName] = ModelWrapper{ + ModelSpec: &modelSpec, + Time: timeNow, + Stale: true, + Redownload: true, + } + } else { + isStale := true + reDownload := true + if oldModel.ModelSpec != nil { + isStale = !cmp.Equal(*oldModel.ModelSpec, modelSpec) + reDownload = !cmp.Equal(oldModel.ModelSpec.StorageURI, modelSpec.StorageURI) + log.Println("same", !isStale, *oldModel.ModelSpec, modelSpec) + } + // Need to store new time, TODO: maybe worth to have seperate map? + w.ModelTracker[modelName] = ModelWrapper{ + ModelSpec: &modelSpec, + Time: timeNow, + Stale: isStale, + Redownload: reDownload, + } + } + } + // TODO: Maybe make parallel and more efficient? 
+ for modelName, modelWrapper := range w.ModelTracker { + if modelWrapper.Time.Before(timeNow) { + delete(w.ModelTracker, modelName) + channel, ok := w.Puller.ChannelMap[modelName] + if !ok { + log.Println("Model", modelName, "was never added to channel map") + } else { + event := EventWrapper{ + ModelName: modelName, + ModelSpec: nil, + LoadState: ShouldUnload, + ShouldDownload: false, + } + log.Println("Sending event", event) + channel.EventChannel <- event + } + } else { + if modelWrapper.Stale { + channel, ok := w.Puller.ChannelMap[modelName] + if !ok { + log.Println("Need to add model", modelName) + channel = w.Puller.AddModel(modelName) + } + event := EventWrapper{ + ModelName: modelName, + ModelSpec: modelWrapper.ModelSpec, + LoadState: ShouldLoad, + ShouldDownload: modelWrapper.Redownload, + } + log.Println("Sending event", event) + channel.EventChannel <- event + } + } + } +} diff --git a/pkg/modelconfig/configmap.go b/pkg/modelconfig/configmap.go index d0b3a40a2f0..783006479ef 100644 --- a/pkg/modelconfig/configmap.go +++ b/pkg/modelconfig/configmap.go @@ -13,8 +13,8 @@ var logger = log.Log.WithName("ModelConfig") var json = jsoniter.ConfigCompatibleWithStandardLibrary type ModelConfig struct { - Name string `json:"modelName"` - Spec v1beta1.ModelSpec `json:"modelSpec"` + Name string `json:"modelName"` + Spec v1beta1.ModelSpec `json:"modelSpec"` } type ModelConfigs []ModelConfig diff --git a/pkg/modelconfig/configmap_test.go b/pkg/modelconfig/configmap_test.go index aa2d8f328c4..b262e4ecf15 100644 --- a/pkg/modelconfig/configmap_test.go +++ b/pkg/modelconfig/configmap_test.go @@ -74,7 +74,7 @@ func TestProcessAddOrUpdate(t *testing.T) { }, }, expected: `[{"modelName":"model1","modelSpec":{"storageUri":"s3//model1","framework":"framework1","memory":"0"}},` + - `{"modelName":"model2","modelSpec":{"storageUri":"s3//model2","framework":"framework2","memory":"0"}}]`, + 
`{"modelName":"model2","modelSpec":{"storageUri":"s3//model2","framework":"framework2","memory":"0"}}]`, }, "update": { modelConfigs: ModelConfigs{ @@ -91,7 +91,7 @@ func TestProcessAddOrUpdate(t *testing.T) { }, }, expected: `[{"modelName":"model1","modelSpec":{"storageUri":"s3//new-model1","framework":"new-framework1","memory":"0"}},` + - `{"modelName":"model2","modelSpec":{"storageUri":"s3//model2","framework":"framework2","memory":"0"}}]`, + `{"modelName":"model2","modelSpec":{"storageUri":"s3//model2","framework":"framework2","memory":"0"}}]`, }, } for _, tc := range testCases { @@ -108,9 +108,9 @@ func TestProcessAddOrUpdate(t *testing.T) { func TestProcessDelete(t *testing.T) { log.SetLogger(log.ZapLogger(true)) testCases := map[string]struct { - modelConfigs []string - configMap *v1.ConfigMap - expected string + modelConfigs []string + configMap *v1.ConfigMap + expected string }{ "delete nil data": { modelConfigs: []string{"model1"}, @@ -174,10 +174,10 @@ func TestProcessDelete(t *testing.T) { func TestProcess(t *testing.T) { log.SetLogger(log.ZapLogger(true)) testCases := map[string]struct { - updated ModelConfigs - deleted []string - configMap *v1.ConfigMap - expected string + updated ModelConfigs + deleted []string + configMap *v1.ConfigMap + expected string }{ "process configmap": { updated: ModelConfigs{ @@ -194,14 +194,12 @@ func TestProcess(t *testing.T) { configMap: &v1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{Name: "test-config", Namespace: "test"}, Data: map[string]string{ - constants.ModelConfigFileName: - `[{"modelName":"model1","modelSpec":{"storageUri":"s3//model1","framework":"framework1","memory":"0"}},` + + constants.ModelConfigFileName: `[{"modelName":"model1","modelSpec":{"storageUri":"s3//model1","framework":"framework1","memory":"0"}},` + `{"modelName":"model2","modelSpec":{"storageUri":"s3//model2","framework":"framework2","memory":"0"}}]`, }, }, expected: 
`[{"modelName":"model1","modelSpec":{"storageUri":"s3//new-model1","framework":"new-framework1","memory":"0"}},` + `{"modelName":"model3","modelSpec":{"storageUri":"s3//model3","framework":"framework3","memory":"0"}}]`, - }, } for _, tc := range testCases {