Skip to content

Commit

Permalink
Refactor ExplainerSpec for better validation (kubeflow#406)
Browse files Browse the repository at this point in the history
* Refactor explainer for better ExplainerSpec validation

* Log error on getExplainer and getPredictor
  • Loading branch information
ariefrahmansyah authored and k8s-ci-robot committed Oct 6, 2019
1 parent c5f93da commit 0d9c115
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 21 deletions.
50 changes: 30 additions & 20 deletions pkg/apis/serving/v1alpha2/explainer.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,41 @@ const (
)

// Returns a URI to the explainer. This URI is passed to the model-initializer via the ModelInitializerSourceUriInternalAnnotationKey
func (m *ExplainerSpec) GetStorageUri() string {
return getExplainerHandler(m).GetStorageUri()
func (e *ExplainerSpec) GetStorageUri() string {
explainer, err := getExplainer(e)
if err != nil {
return ""
}
return explainer.GetStorageUri()
}

func (m *ExplainerSpec) CreateExplainerContainer(modelName string, predictorHost string, config *ExplainersConfig) *v1.Container {
return getExplainerHandler(m).CreateExplainerContainer(modelName, predictorHost, config)
func (e *ExplainerSpec) CreateExplainerContainer(modelName string, predictorHost string, config *ExplainersConfig) *v1.Container {
explainer, err := getExplainer(e)
if err != nil {
return nil
}
return explainer.CreateExplainerContainer(modelName, predictorHost, config)
}

func (m *ExplainerSpec) ApplyDefaults() {
getExplainerHandler(m).ApplyDefaults()
func (e *ExplainerSpec) ApplyDefaults() {
explainer, err := getExplainer(e)
if err == nil {
explainer.ApplyDefaults()
}
}

func (m *ExplainerSpec) Validate() error {
explainer, err := makeExplainer(m)
func (e *ExplainerSpec) Validate() error {
explainer, err := getExplainer(e)
if err != nil {
return err
}
return explainer.Validate()
if err := explainer.Validate(); err != nil {
return err
}
if err := validateStorageURI(e.GetStorageUri()); err != nil {
return err
}
return nil
}

type ExplainerConfig struct {
Expand All @@ -62,16 +79,7 @@ type ExplainersConfig struct {
AlibiExplainer ExplainerConfig `json:"alibi,omitempty"`
}

func getExplainerHandler(modelSpec *ExplainerSpec) Explainer {
explainer, err := makeExplainer(modelSpec)
if err != nil {
klog.Fatal(err)
}

return explainer
}

func makeExplainer(explainerSpec *ExplainerSpec) (Explainer, error) {
func getExplainer(explainerSpec *ExplainerSpec) (Explainer, error) {
handlers := []Explainer{}
if explainerSpec.Custom != nil {
handlers = append(handlers, explainerSpec.Custom)
Expand All @@ -80,7 +88,9 @@ func makeExplainer(explainerSpec *ExplainerSpec) (Explainer, error) {
handlers = append(handlers, explainerSpec.Alibi)
}
if len(handlers) != 1 {
return nil, fmt.Errorf(ExactlyOneExplainerViolatedError)
err := fmt.Errorf(ExactlyOneExplainerViolatedError)
klog.Error(err)
return nil, err
}
return handlers[0], nil
}
5 changes: 4 additions & 1 deletion pkg/apis/serving/v1alpha2/predictor.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/kubeflow/kfserving/pkg/constants"
v1 "k8s.io/api/core/v1"
resource "k8s.io/apimachinery/pkg/api/resource"
"k8s.io/klog"
)

type Predictor interface {
Expand Down Expand Up @@ -186,7 +187,9 @@ func getPredictor(predictorSpec *PredictorSpec) (Predictor, error) {
predictors = append(predictors, predictorSpec.TensorRT)
}
if len(predictors) != 1 {
return nil, fmt.Errorf(ExactlyOnePredictorViolatedError)
err := fmt.Errorf(ExactlyOnePredictorViolatedError)
klog.Error(err)
return nil, err
}
return predictors[0], nil
}

0 comments on commit 0d9c115

Please sign in to comment.