diff --git a/pkg/apis/pytorch/v1/defaults.go b/pkg/apis/pytorch/v1/defaults.go index 13ac7d44e..d8504d55e 100644 --- a/pkg/apis/pytorch/v1/defaults.go +++ b/pkg/apis/pytorch/v1/defaults.go @@ -95,10 +95,12 @@ func SetDefaults_PyTorchJob(job *PyTorchJob) { // Update the key of PyTorchReplicaSpecs to camel case. setTypeNamesToCamelCase(job) - for _, spec := range job.Spec.PyTorchReplicaSpecs { + for rType, spec := range job.Spec.PyTorchReplicaSpecs { // Set default replicas to 1. setDefaultReplicas(spec) - // Set default port to pytorch container. - setDefaultPort(&spec.Template.Spec) + if rType == PyTorchReplicaTypeMaster { + // Set default port to pytorch container of Master. + setDefaultPort(&spec.Template.Spec) + } } } diff --git a/pkg/controller.v1/pytorch/controller.go b/pkg/controller.v1/pytorch/controller.go index 14408256a..b87a267e8 100644 --- a/pkg/controller.v1/pytorch/controller.go +++ b/pkg/controller.v1/pytorch/controller.go @@ -455,6 +455,10 @@ func (pc *PyTorchController) reconcilePyTorchJobs(job *pyv1.PyTorchJob) error { return err } + // Service is in need only for Master + if rtype != pyv1.PyTorchReplicaTypeMaster { + continue + } err = pc.reconcileServices(job, services, rtype, spec) if err != nil { diff --git a/pkg/controller.v1/pytorch/controller_test.go b/pkg/controller.v1/pytorch/controller_test.go index dc2ad24de..632e67388 100644 --- a/pkg/controller.v1/pytorch/controller_test.go +++ b/pkg/controller.v1/pytorch/controller_test.go @@ -128,7 +128,7 @@ func TestNormalPath(t *testing.T) { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 5, 0, 5, + 5, 0, 1, 0, 0, 0, 0, 0, 0, nil, "", @@ -139,7 +139,7 @@ func TestNormalPath(t *testing.T) { nil, true, 4, 0, 0, 0, 1, 0, 0, 0, - 4, 1, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -151,7 +151,7 @@ func TestNormalPath(t *testing.T) { nil, true, 3, 1, 0, 0, 0, 1, 0, 0, - 4, 1, + 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, @@ -163,7 +163,7 @@ func TestNormalPath(t *testing.T) { nil, true, 0, 4, 0, 0, 0, 1, 0, 0, - 4, 1, + 0, 1, 0, 0, 0, 4, 0, 0, 1, 0, 0, @@ -175,7 +175,7 @@ func TestNormalPath(t *testing.T) { nil, true, 0, 0, 4, 0, 0, 0, 1, 0, - 4, 1, + 0, 1, 0, 0, 0, 0, 4, 0, 0, 1, 0,