Skip to content

Commit

Permalink
Consolidate validation and defaulting logic (#376)
Browse files Browse the repository at this point in the history
Validation happens in a single place, improving coverage
  • Loading branch information
alculquicondor authored Jul 15, 2021
1 parent e9547bf commit e80137c
Show file tree
Hide file tree
Showing 6 changed files with 438 additions and 61 deletions.
37 changes: 26 additions & 11 deletions v2/pkg/apis/kubeflow/v2/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,42 @@ import (
"k8s.io/apimachinery/pkg/runtime"
)

// Int32 is a helper routine that allocates a new int32 value
// to store v and returns a pointer to it.
func Int32(v int32) *int32 {
return &v
}

func addDefaultingFuncs(scheme *runtime.Scheme) error {
return RegisterDefaults(scheme)
}

// setDefaultsTypeLauncher sets the default value to launcher.
func setDefaultsTypeLauncher(spec *common.ReplicaSpec) {
if spec != nil && spec.RestartPolicy == "" {
if spec == nil {
return
}
if spec.RestartPolicy == "" {
spec.RestartPolicy = DefaultRestartPolicy
}
if spec.Replicas == nil {
spec.Replicas = newInt32(1)
}
}

// setDefaultsTypeWorker sets the default value to worker.
func setDefaultsTypeWorker(spec *common.ReplicaSpec) {
if spec != nil && spec.RestartPolicy == "" {
if spec == nil {
return
}
if spec.RestartPolicy == "" {
spec.RestartPolicy = DefaultRestartPolicy
}
if spec.Replicas == nil {
spec.Replicas = newInt32(0)
}
}

func SetDefaults_MPIJob(mpiJob *MPIJob) {
// Set default cleanpod policy to None.
if mpiJob.Spec.CleanPodPolicy == nil {
none := common.CleanPodPolicyNone
mpiJob.Spec.CleanPodPolicy = &none
mpiJob.Spec.CleanPodPolicy = newCleanPodPolicy(common.CleanPodPolicyNone)
}
if mpiJob.Spec.SlotsPerWorker == nil {
mpiJob.Spec.SlotsPerWorker = newInt32(1)
}

// set default to Launcher
Expand All @@ -56,3 +63,11 @@ func SetDefaults_MPIJob(mpiJob *MPIJob) {
// set default to Worker
setDefaultsTypeWorker(mpiJob.Spec.MPIReplicaSpecs[MPIReplicaTypeWorker])
}

func newInt32(v int32) *int32 {
return &v
}

func newCleanPodPolicy(policy common.CleanPodPolicy) *common.CleanPodPolicy {
return &policy
}
103 changes: 103 additions & 0 deletions v2/pkg/apis/kubeflow/v2/default_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright 2021 The Kubeflow Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package v2

import (
"testing"

"github.com/google/go-cmp/cmp"
common "github.com/kubeflow/common/pkg/apis/common/v1"
)

func TestSetDefaults_MPIJob(t *testing.T) {
cases := map[string]struct {
job MPIJob
want MPIJob
}{
"base defaults": {
want: MPIJob{
Spec: MPIJobSpec{
SlotsPerWorker: newInt32(1),
CleanPodPolicy: newCleanPodPolicy(common.CleanPodPolicyNone),
},
},
},
"base defaults overridden": {
job: MPIJob{
Spec: MPIJobSpec{
SlotsPerWorker: newInt32(10),
CleanPodPolicy: newCleanPodPolicy(common.CleanPodPolicyRunning),
},
},
want: MPIJob{
Spec: MPIJobSpec{
SlotsPerWorker: newInt32(10),
CleanPodPolicy: newCleanPodPolicy(common.CleanPodPolicyRunning),
},
},
},
"launcher defaults": {
job: MPIJob{
Spec: MPIJobSpec{
MPIReplicaSpecs: map[MPIReplicaType]*common.ReplicaSpec{
MPIReplicaTypeLauncher: {},
},
},
},
want: MPIJob{
Spec: MPIJobSpec{
SlotsPerWorker: newInt32(1),
CleanPodPolicy: newCleanPodPolicy(common.CleanPodPolicyNone),
MPIReplicaSpecs: map[MPIReplicaType]*common.ReplicaSpec{
MPIReplicaTypeLauncher: {
Replicas: newInt32(1),
RestartPolicy: DefaultRestartPolicy,
},
},
},
},
},
"worker defaults": {
job: MPIJob{
Spec: MPIJobSpec{
MPIReplicaSpecs: map[MPIReplicaType]*common.ReplicaSpec{
MPIReplicaTypeWorker: {},
},
},
},
want: MPIJob{
Spec: MPIJobSpec{
SlotsPerWorker: newInt32(1),
CleanPodPolicy: newCleanPodPolicy(common.CleanPodPolicyNone),
MPIReplicaSpecs: map[MPIReplicaType]*common.ReplicaSpec{
MPIReplicaTypeWorker: {
Replicas: newInt32(0),
RestartPolicy: DefaultRestartPolicy,
},
},
},
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
got := tc.job.DeepCopy()
SetDefaults_MPIJob(got)
if diff := cmp.Diff(tc.want, *got); diff != "" {
t.Errorf("Unexpected changes (-want,+got):\n%s", diff)
}
})
}
}
96 changes: 96 additions & 0 deletions v2/pkg/apis/kubeflow/validation/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright 2021 The Kubeflow Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package validation

import (
"fmt"

common "github.com/kubeflow/common/pkg/apis/common/v1"
v2 "github.com/kubeflow/mpi-operator/v2/pkg/apis/kubeflow/v2"
apivalidation "k8s.io/apimachinery/pkg/api/validation"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/validation/field"
)

var validCleanPolicies = sets.NewString(
string(common.CleanPodPolicyNone),
string(common.CleanPodPolicyRunning),
string(common.CleanPodPolicyAll))

func ValidateMPIJob(job *v2.MPIJob) field.ErrorList {
return validateMPIJobSpec(&job.Spec, field.NewPath("spec"))
}

func validateMPIJobSpec(spec *v2.MPIJobSpec, path *field.Path) field.ErrorList {
errs := validateMPIReplicaSpecs(spec.MPIReplicaSpecs, path.Child("mpiReplicaSpecs"))
if spec.SlotsPerWorker == nil {
errs = append(errs, field.Required(path.Child("slotsPerWorker"), "must have number of slots per worker"))
} else {
errs = append(errs, apivalidation.ValidateNonnegativeField(int64(*spec.SlotsPerWorker), path.Child("slotsPerWorker"))...)
}
if spec.CleanPodPolicy == nil {
errs = append(errs, field.Required(path.Child("cleanPodPolicy"), "must have clean Pod policy"))
} else if !validCleanPolicies.Has(string(*spec.CleanPodPolicy)) {
errs = append(errs, field.NotSupported(path.Child("cleanPodPolicy"), *spec.CleanPodPolicy, validCleanPolicies.List()))
}
return errs
}

func validateMPIReplicaSpecs(replicaSpecs map[v2.MPIReplicaType]*common.ReplicaSpec, path *field.Path) field.ErrorList {
var errs field.ErrorList
if replicaSpecs == nil {
errs = append(errs, field.Required(path, "must have replica specs"))
return errs
}
errs = append(errs, validateLauncherReplicaSpec(replicaSpecs[v2.MPIReplicaTypeLauncher], path.Key(string(v2.MPIReplicaTypeLauncher)))...)
errs = append(errs, validateWorkerReplicaSpec(replicaSpecs[v2.MPIReplicaTypeWorker], path.Key(string(v2.MPIReplicaTypeWorker)))...)
return errs
}

func validateLauncherReplicaSpec(spec *common.ReplicaSpec, path *field.Path) field.ErrorList {
var errs field.ErrorList
if spec == nil {
errs = append(errs, field.Required(path, fmt.Sprintf("must have %s replica spec", v2.MPIReplicaTypeLauncher)))
return errs
}
errs = append(errs, validateReplicaSpec(spec, path)...)
if spec.Replicas != nil && *spec.Replicas != 1 {
errs = append(errs, field.Invalid(path.Child("replicas"), *spec.Replicas, "must be 1"))
}
return errs
}

func validateWorkerReplicaSpec(spec *common.ReplicaSpec, path *field.Path) field.ErrorList {
var errs field.ErrorList
if spec == nil {
return errs
}
errs = append(errs, validateReplicaSpec(spec, path)...)
if spec.Replicas != nil {
errs = append(errs, apivalidation.ValidateNonnegativeField(int64(*spec.Replicas), path.Child("replicas"))...)
}
return errs
}

func validateReplicaSpec(spec *common.ReplicaSpec, path *field.Path) field.ErrorList {
var errs field.ErrorList
if spec.Replicas == nil {
errs = append(errs, field.Required(path.Child("replicas"), "must define number of replicas"))
}
if len(spec.Template.Spec.Containers) == 0 {
errs = append(errs, field.Required(path.Child("template", "spec", "containers"), "must define at least one container"))
}
return errs
}
Loading

0 comments on commit e80137c

Please sign in to comment.