Skip to content

Commit

Permalink
PyTorchJob: Always show warnings when using elasticPolicy.nProcPerNode (kubeflow#2067)
Browse files Browse the repository at this point in the history

Signed-off-by: Yuki Iwai <yuki.iwai.tz@gmail.com>
Signed-off-by: Weiyu Yen <ckyuto@gmail.com>
  • Loading branch information
tenzen-y authored and ckyuto committed Apr 26, 2024
1 parent a37ada6 commit 8f1f3c9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
6 changes: 4 additions & 2 deletions pkg/webhooks/pytorch/pytorchjob_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,13 @@ func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, fie
var allErrs field.ErrorList
var warnings admission.Warnings

if spec.NprocPerNode != nil && spec.ElasticPolicy != nil && spec.ElasticPolicy.NProcPerNode != nil {
if spec.ElasticPolicy != nil && spec.ElasticPolicy.NProcPerNode != nil {
elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode")
nprocPerNodePath := specPath.Child("nprocPerNode")
allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath)))
warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String()))
if spec.NprocPerNode != nil {
allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath)))
}
}
allErrs = append(allErrs, validatePyTorchReplicaSpecs(spec.PyTorchReplicaSpecs)...)
return warnings, allErrs
Expand Down
31 changes: 31 additions & 0 deletions pkg/webhooks/pytorch/pytorchjob_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,37 @@ func TestValidateV1PyTorchJob(t *testing.T) {
field.Forbidden(pytorchReplicaSpecPath.Key(string(trainingoperator.PyTorchJobReplicaTypeMaster)).Child("replicas"), ""),
},
},
"Spec.ElasticPolicy.NProcPerNode are set": {
pytorchJob: &trainingoperator.PyTorchJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.PyTorchJobSpec{
ElasticPolicy: &trainingoperator.ElasticPolicy{
NProcPerNode: ptr.To[int32](1),
},
PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.PyTorchJobReplicaTypeMaster: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
wantWarnings: admission.Warnings{
fmt.Sprintf("%s is deprecated, use %s instead",
specPath.Child("elasticPolicy").Child("nProcPerNode"), specPath.Child("nprocPerNode")),
},
},
"Spec.NprocPerNode and Spec.ElasticPolicy.NProcPerNode are set": {
pytorchJob: &trainingoperator.PyTorchJob{
ObjectMeta: metav1.ObjectMeta{
Expand Down

0 comments on commit 8f1f3c9

Please sign in to comment.