Skip to content

Commit

Permalink
Add envvar tests
Browse files Browse the repository at this point in the history
Signed-off-by: Sandipan Panda <samparksandipan@gmail.com>
  • Loading branch information
sandipanpanda committed Sep 20, 2024
1 parent a86c927 commit 32aad29
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 6 deletions.
6 changes: 3 additions & 3 deletions examples/jax/cpu-demo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ def _main(argv):
)

print(
f"JAX process {jax.process_index()}/{jax.process_count()} initialized on "
f"JAX process {jax.process_index()}/{jax.process_count() - 1} initialized on "
f"{socket.gethostname()}"
)
print(f"JAX global devices:{jax.devices()}")
print(f"JAX local devices:{jax.local_devices()}")

print(jax.device_count())
print(jax.local_device_count())
print(f"JAX device count:{jax.device_count()}")
print(f"JAX local device count:{jax.local_device_count()}")

xs = jax.numpy.ones(jax.local_device_count())
print(jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs))
Expand Down
11 changes: 8 additions & 3 deletions pkg/controller.v1/jax/envvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
package jax

import (
"fmt"
"errors"
"strconv"
"strings"

Expand All @@ -25,6 +25,11 @@ import (
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
)

// Sentinel errors returned by the JAXJob environment-variable helpers.
// Callers and tests compare against these values rather than error strings.
var (
// errorDefaulContainerPortNotExposed is returned by getPortFromJAXJob when
// it reaches the end of its search without finding the port it looks for.
// NOTE(review): the identifier is missing the "t" in "Default"; renaming it
// requires updating every reference in the package at the same time.
errorDefaulContainerPortNotExposed = errors.New("default container port is not exposed")
// errorFailedToRecognizeRank is returned by setPodEnv when the replica index
// string cannot be converted to an integer with strconv.Atoi.
errorFailedToRecognizeRank = errors.New("failed to recognize the JAXJob Rank")
)

// EnvVarGenerator is implemented by types that derive a list of
// environment variables from a JAXJob.
type EnvVarGenerator interface {
// Generate returns the environment variables computed for job, or an error
// if they cannot be produced.
Generate(job *kubeflowv1.JAXJob) ([]corev1.EnvVar, error)
}
Expand All @@ -44,7 +49,7 @@ func setPodEnv(jaxjob *kubeflowv1.JAXJob, podTemplateSpec *corev1.PodTemplateSpe

rank, err := strconv.Atoi(index)
if err != nil {
return err
return errorFailedToRecognizeRank
}
// Set PYTHONUNBUFFERED to true, to disable output buffering.
// Ref https://stackoverflow.com/questions/59812009/what-is-the-use-of-pythonunbuffered-in-docker-file.
Expand Down Expand Up @@ -98,5 +103,5 @@ func getPortFromJAXJob(job *kubeflowv1.JAXJob, rtype kubeflowv1.ReplicaType) (in
}
}
}
return -1, fmt.Errorf("port not found")
return -1, errorDefaulContainerPortNotExposed
}
138 changes: 138 additions & 0 deletions pkg/controller.v1/jax/envvar_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package jax

import (
"strconv"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/utils/ptr"

kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
)

func TestSetPodEnv(t *testing.T) {
// Define some helper variables/constants for the test cases
validPort := int32(6666)
validIndex := "0"
invalidIndex := "invalid"

// Define a valid JAXJob structure
validJAXJob := &kubeflowv1.JAXJob{
ObjectMeta: metav1.ObjectMeta{Name: "test-jaxjob"},
Spec: kubeflowv1.JAXJobSpec{
JAXReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
kubeflowv1.JAXJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: "jax",
Image: "docker.io/sandipanify/jaxgoogle:latest",
Ports: []corev1.ContainerPort{{
Name: kubeflowv1.JAXJobDefaultPortName,
ContainerPort: validPort,
}},
ImagePullPolicy: corev1.PullAlways,
Command: []string{
"python",
"train.py",
},
}},
},
},
},
},
},
}

// Define the test cases
cases := map[string]struct {
jaxJob *kubeflowv1.JAXJob
podTemplate *corev1.PodTemplateSpec
rtype kubeflowv1.ReplicaType
index string
wantPodEnvVars []corev1.EnvVar
wantErr error
}{
"successful environment variable setup": {
jaxJob: validJAXJob,
podTemplate: &corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{}},
},
},
rtype: kubeflowv1.JAXJobReplicaTypeWorker,
index: validIndex,
wantPodEnvVars: []corev1.EnvVar{
{Name: "PYTHONUNBUFFERED", Value: "1"},
{Name: "COORDINATOR_PORT", Value: strconv.Itoa(int(validPort))},
{Name: "COORDINATOR_ADDRESS", Value: "test-jaxjob-worker-0"},
{Name: "NUM_PROCESSES", Value: "1"},
{Name: "PROCESS_ID", Value: validIndex},
},
wantErr: nil,
},
"invalid index for PROCESS_ID": {
jaxJob: validJAXJob,
podTemplate: &corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{}},
},
},
rtype: kubeflowv1.JAXJobReplicaTypeWorker,
index: invalidIndex,
wantErr: errorFailedToRecognizeRank,
},
"missing container port in JAXJob": {
jaxJob: &kubeflowv1.JAXJob{
Spec: kubeflowv1.JAXJobSpec{
JAXReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
kubeflowv1.JAXJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: "jax",
Ports: []corev1.ContainerPort{
{Name: "wrong-port", ContainerPort: 0},
},
}},
},
},
},
},
},
},
podTemplate: &corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{}},
},
},
rtype: kubeflowv1.JAXJobReplicaTypeWorker,
index: validIndex,
wantErr: errorDefaulContainerPortNotExposed,
},
}

// Execute the test cases
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
err := setPodEnv(tc.jaxJob, tc.podTemplate, string(tc.rtype), tc.index)

// Check if an error was expected
if diff := cmp.Diff(tc.wantErr, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}

for i, container := range tc.podTemplate.Spec.Containers {
if diff := cmp.Diff(tc.wantPodEnvVars, container.Env); diff != "" {
t.Errorf("Unexpected env vars for container %d (-want,+got):\n%s", i, diff)
}
}

})
}
}

0 comments on commit 32aad29

Please sign in to comment.