diff --git a/pkg/webhooks/pytorch/pytorchjob_webhook.go b/pkg/webhooks/pytorch/pytorchjob_webhook.go index 1dd17a3376..14d7b5c0eb 100644 --- a/pkg/webhooks/pytorch/pytorchjob_webhook.go +++ b/pkg/webhooks/pytorch/pytorchjob_webhook.go @@ -88,11 +88,13 @@ func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, fie var allErrs field.ErrorList var warnings admission.Warnings - if spec.NprocPerNode != nil && spec.ElasticPolicy != nil && spec.ElasticPolicy.NProcPerNode != nil { + if spec.ElasticPolicy != nil && spec.ElasticPolicy.NProcPerNode != nil { elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode") nprocPerNodePath := specPath.Child("nprocPerNode") - allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath))) warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String())) + if spec.NprocPerNode != nil { + allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath))) + } } allErrs = append(allErrs, validatePyTorchReplicaSpecs(spec.PyTorchReplicaSpecs)...) return warnings, allErrs diff --git a/pkg/webhooks/pytorch/pytorchjob_webhook_test.go b/pkg/webhooks/pytorch/pytorchjob_webhook_test.go index 8f2e492293..362a6d91e9 100644 --- a/pkg/webhooks/pytorch/pytorchjob_webhook_test.go +++ b/pkg/webhooks/pytorch/pytorchjob_webhook_test.go @@ -218,6 +218,37 @@ func TestValidateV1PyTorchJob(t *testing.T) { field.Forbidden(pytorchReplicaSpecPath.Key(string(trainingoperator.PyTorchJobReplicaTypeMaster)).Child("replicas"), ""), }, }, + "Spec.ElasticPolicy.NProcPerNode are set": { + pytorchJob: &trainingoperator.PyTorchJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PyTorchJobSpec{ + ElasticPolicy: &trainingoperator.ElasticPolicy{ + NProcPerNode: ptr.To[int32](1), + }, + PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PyTorchJobReplicaTypeMaster: { + Replicas: ptr.To[int32](1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "pytorch", + Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0", + }, + }, + }, + }, + }, + }, + }, + }, + wantWarnings: admission.Warnings{ + fmt.Sprintf("%s is deprecated, use %s instead", + specPath.Child("elasticPolicy").Child("nProcPerNode"), specPath.Child("nprocPerNode")), + }, + }, "Spec.NprocPerNode and Spec.ElasticPolicy.NProcPerNode are set": { pytorchJob: &trainingoperator.PyTorchJob{ ObjectMeta: metav1.ObjectMeta{