From a86c927ed7d0262c1e5ce651adb1b6d9333d0fd6 Mon Sep 17 00:00:00 2001 From: Sandipan Panda Date: Mon, 5 Aug 2024 21:28:00 +0530 Subject: [PATCH] Add JAX controller Add JAX controller, controller tests, webhook validations, examples, e2e tests for JAXJob Extend the Training Operator Python SDK to simplify the creation and management of JAXJob resources. Signed-off-by: Sandipan Panda --- .github/workflows/publish-example-images.yaml | 4 + PROJECT | 8 + cmd/training-operator.v1/main.go | 2 +- examples/jax/cpu-demo/Dockerfile | 25 + examples/jax/cpu-demo/demo.yaml | 19 + examples/jax/cpu-demo/train.py | 43 ++ manifests/base/crds/kustomization.yaml | 1 + manifests/base/rbac/role.yaml | 26 + manifests/base/webhook/manifests.yaml | 20 + manifests/base/webhook/patch.yaml | 3 + pkg/controller.v1/jax/envvar.go | 102 ++++ pkg/controller.v1/jax/jaxjob_controller.go | 488 ++++++++++++++++++ .../jax/jaxjob_controller_suite_test.go | 125 +++++ .../jax/jaxjob_controller_test.go | 316 ++++++++++++ pkg/controller.v1/register_controller.go | 4 + pkg/webhooks/jax/jaxjob_webhook.go | 124 +++++ pkg/webhooks/jax/jaxjob_webhook_test.go | 198 +++++++ pkg/webhooks/webhooks.go | 2 + .../kubeflow/training/api/training_client.py | 8 + .../kubeflow/training/constants/constants.py | 14 + sdk/python/test/e2e/test_e2e_jaxjob.py | 161 ++++++ 21 files changed, 1692 insertions(+), 1 deletion(-) create mode 100644 examples/jax/cpu-demo/Dockerfile create mode 100644 examples/jax/cpu-demo/demo.yaml create mode 100644 examples/jax/cpu-demo/train.py create mode 100644 pkg/controller.v1/jax/envvar.go create mode 100644 pkg/controller.v1/jax/jaxjob_controller.go create mode 100644 pkg/controller.v1/jax/jaxjob_controller_suite_test.go create mode 100644 pkg/controller.v1/jax/jaxjob_controller_test.go create mode 100644 pkg/webhooks/jax/jaxjob_webhook.go create mode 100644 pkg/webhooks/jax/jaxjob_webhook_test.go create mode 100644 sdk/python/test/e2e/test_e2e_jaxjob.py diff --git a/.github/workflows/publish-example-images.yaml b/.github/workflows/publish-example-images.yaml index 5b0902a8bb..9f25e59939 100644 --- a/.github/workflows/publish-example-images.yaml +++ b/.github/workflows/publish-example-images.yaml @@ -69,3 +69,7 @@ jobs: platforms: linux/amd64,linux/arm64 dockerfile: examples/pytorch/mnist/Dockerfile-mpi context: examples/pytorch/mnist + - component-name: jaxjob-simple + platforms: linux/amd64,linux/arm64 + dockerfile: examples/jax/cpu-demo/Dockerfile + context: examples/jax/cpu-demo diff --git a/PROJECT b/PROJECT index 8f321f6b86..4aea9cdea0 100644 --- a/PROJECT +++ b/PROJECT @@ -27,4 +27,12 @@ resources: kind: TFJob path: github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1 version: v1 +- api: + crdVersion: v1 + namespaced: true + controller: true + group: kubeflow.org + kind: JAXJob + path: github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1 + version: v1 version: "3" diff --git a/cmd/training-operator.v1/main.go b/cmd/training-operator.v1/main.go index bec5cd6b55..e008c05d4e 100644 --- a/cmd/training-operator.v1/main.go +++ b/cmd/training-operator.v1/main.go @@ -87,7 +87,7 @@ func main() { "Enabling this will ensure there is only one active controller manager.") flag.StringVar(&leaderElectionID, "leader-election-id", "1ca428e5.training-operator.kubeflow.org", "The ID for leader election.") flag.Var(&enabledSchemes, "enable-scheme", "Enable scheme(s) as --enable-scheme=tfjob --enable-scheme=pytorchjob, case insensitive."+ - " Now supporting TFJob, PyTorchJob, XGBoostJob, PaddleJob. By default, all supported schemes will be enabled.") + " Now supporting TFJob, PyTorchJob, XGBoostJob, PaddleJob, JAXJob. By default, all supported schemes will be enabled.") flag.StringVar(&gangSchedulerName, "gang-scheduler-name", "", "Now Supporting volcano and scheduler-plugins."+ " Note: If you set another scheduler name, the training-operator assumes it's the scheduler-plugins.") flag.StringVar(&namespace, "namespace", os.Getenv(EnvKubeflowNamespace), "The namespace to monitor kubeflow jobs. If unset, it monitors all namespaces cluster-wide."+ diff --git a/examples/jax/cpu-demo/Dockerfile b/examples/jax/cpu-demo/Dockerfile new file mode 100644 index 0000000000..6e2f59b834 --- /dev/null +++ b/examples/jax/cpu-demo/Dockerfile @@ -0,0 +1,25 @@ +FROM python:3.12 + +RUN pip install jax absl-py kubernetes + +RUN apt-get update && apt-get install -y \ + build-essential \ + cmake \ + git \ + libgoogle-glog-dev \ + libgflags-dev \ + libprotobuf-dev \ + protobuf-compiler \ + && rm -rf /var/lib/apt/lists/* + +RUN git clone https://github.com/facebookincubator/gloo.git \ + && cd gloo \ + && mkdir build \ + && cd build \ + && cmake ../ \ + && make \ + && make install + +WORKDIR /app + +ADD train.py /app diff --git a/examples/jax/cpu-demo/demo.yaml b/examples/jax/cpu-demo/demo.yaml new file mode 100644 index 0000000000..85c99c9b18 --- /dev/null +++ b/examples/jax/cpu-demo/demo.yaml @@ -0,0 +1,19 @@ +apiVersion: "kubeflow.org/v1" +kind: JAXJob +metadata: + name: jaxjob-simple + namespace: kubeflow +spec: + jaxReplicaSpecs: + Worker: + replicas: 2 + restartPolicy: OnFailure + template: + spec: + containers: + - name: jax + image: docker.io/sandipanify/jaxgoogle:latest + command: + - "python3" + - "train.py" + imagePullPolicy: Always diff --git a/examples/jax/cpu-demo/train.py b/examples/jax/cpu-demo/train.py new file mode 100644 index 0000000000..0383374555 --- /dev/null +++ b/examples/jax/cpu-demo/train.py @@ -0,0 +1,43 @@ +# example ref: +# https://jax.readthedocs.io/en/latest/multi_process.html#running-multi-process-computations +# https://github.com/GoogleCloudPlatform/ai-on-gke/blob/main/tutorials-and-examples/gpu-examples/a100-jax/train.py # noqa + +import os +import socket + +import jax +from absl import app + +jax.config.update("jax_cpu_collectives_implementation", "gloo") + + +def _main(argv): + + process_id = int(os.getenv("PROCESS_ID")) + num_processes = int(os.getenv("NUM_PROCESSES")) + coordinator_address = os.getenv("COORDINATOR_ADDRESS") + coordinator_port = int(os.getenv("COORDINATOR_PORT")) + coordinator_address = f"{coordinator_address}:{coordinator_port}" + + jax.distributed.initialize( + coordinator_address=coordinator_address, + num_processes=num_processes, + process_id=process_id, + ) + + print( + f"JAX process {jax.process_index()}/{jax.process_count()} initialized on " + f"{socket.gethostname()}" + ) + print(f"JAX global devices:{jax.devices()}") + print(f"JAX local devices:{jax.local_devices()}") + + print(jax.device_count()) + print(jax.local_device_count()) + + xs = jax.numpy.ones(jax.local_device_count()) + print(jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)) + + +if __name__ == "__main__": + app.run(_main) diff --git a/manifests/base/crds/kustomization.yaml b/manifests/base/crds/kustomization.yaml index 035ab181ea..16d824bc23 100644 --- a/manifests/base/crds/kustomization.yaml +++ b/manifests/base/crds/kustomization.yaml @@ -6,3 +6,4 @@ resources: - kubeflow.org_xgboostjobs.yaml - kubeflow.org_mpijobs.yaml - kubeflow.org_paddlejobs.yaml + - kubeflow.org_jaxjobs.yaml diff --git a/manifests/base/rbac/role.yaml b/manifests/base/rbac/role.yaml index 281f220121..ae91d43ba8 100644 --- a/manifests/base/rbac/role.yaml +++ b/manifests/base/rbac/role.yaml @@ -92,6 +92,32 @@ rules: - patch - update - watch +- apiGroups: + - kubeflow.org + resources: + - jaxjobs + verbs: + - create + - delete + - get + - list + - patch + - update + - watch +- apiGroups: + - kubeflow.org + resources: + - jaxjobs/finalizers + verbs: + - update +- apiGroups: + - kubeflow.org + resources: + - jaxjobs/status + verbs: + - get + - patch + - update - apiGroups: - kubeflow.org resources: diff --git a/manifests/base/webhook/manifests.yaml b/manifests/base/webhook/manifests.yaml index c8a69845b7..2c381d0cd1 100644 --- a/manifests/base/webhook/manifests.yaml +++ b/manifests/base/webhook/manifests.yaml @@ -4,6 +4,26 @@ kind: ValidatingWebhookConfiguration metadata: name: validating-webhook-configuration webhooks: +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /validate-kubeflow-org-v1-jaxjob + failurePolicy: Fail + name: validator.jaxjob.training-operator.kubeflow.org + rules: + - apiGroups: + - kubeflow.org + apiVersions: + - v1 + operations: + - CREATE + - UPDATE + resources: + - jaxjobs + sideEffects: None - admissionReviewVersions: - v1 clientConfig: diff --git a/manifests/base/webhook/patch.yaml b/manifests/base/webhook/patch.yaml index a02b11bf1c..b103423df2 100644 --- a/manifests/base/webhook/patch.yaml +++ b/manifests/base/webhook/patch.yaml @@ -10,6 +10,9 @@ - op: replace path: /webhooks/3/clientConfig/service/name value: training-operator +- op: replace + path: /webhooks/4/clientConfig/service/name + value: training-operator - op: replace path: /metadata/name value: validator.training-operator.kubeflow.org diff --git a/pkg/controller.v1/jax/envvar.go b/pkg/controller.v1/jax/envvar.go new file mode 100644 index 0000000000..423ea809e7 --- /dev/null +++ b/pkg/controller.v1/jax/envvar.go @@ -0,0 +1,102 @@ +// Copyright 2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +package jax + +import ( + "fmt" + "strconv" + "strings" + + corev1 "k8s.io/api/core/v1" + "k8s.io/utils/ptr" + + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +type EnvVarGenerator interface { + Generate(job *kubeflowv1.JAXJob) ([]corev1.EnvVar, error) +} + +func setPodEnv(jaxjob *kubeflowv1.JAXJob, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error { + + coordinatorAddr := replicaName(jaxjob.Name, kubeflowv1.JAXJobReplicaTypeWorker, 0) + + coordinatorPort, err := getPortFromJAXJob(jaxjob, kubeflowv1.JAXJobReplicaTypeWorker) + if err != nil { + return err + } + + totalReplicas := getTotalReplicas(jaxjob) + + for i := range podTemplateSpec.Spec.Containers { + + rank, err := strconv.Atoi(index) + if err != nil { + return err + } + // Set PYTHONUNBUFFERED to true, to disable output buffering. + // Ref https://stackoverflow.com/questions/59812009/what-is-the-use-of-pythonunbuffered-in-docker-file. + podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{ + Name: "PYTHONUNBUFFERED", + Value: "1", + }) + podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{ + Name: "COORDINATOR_PORT", + Value: strconv.Itoa(int(coordinatorPort)), + }) + podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{ + Name: "COORDINATOR_ADDRESS", + Value: coordinatorAddr, + }) + podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{ + Name: "NUM_PROCESSES", + Value: strconv.Itoa(int(totalReplicas)), + }) + podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{ + Name: "PROCESS_ID", + Value: strconv.Itoa(rank), + }) + } + + return nil +} + +func getTotalReplicas(job *kubeflowv1.JAXJob) int { + jobReplicas := 0 + for _, r := range job.Spec.JAXReplicaSpecs { + jobReplicas += int(ptr.Deref[int32](r.Replicas, 0)) + } + return jobReplicas +} + +func replicaName(jobName string, rtype kubeflowv1.ReplicaType, index int) string { + n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + strconv.Itoa(index) + return strings.Replace(n, "/", "-", -1) +} + +func getPortFromJAXJob(job *kubeflowv1.JAXJob, rtype kubeflowv1.ReplicaType) (int32, error) { + containers := job.Spec.JAXReplicaSpecs[rtype].Template.Spec.Containers + for _, container := range containers { + if container.Name == kubeflowv1.JAXJobDefaultContainerName { + ports := container.Ports + for _, port := range ports { + if port.Name == kubeflowv1.JAXJobDefaultPortName { + return port.ContainerPort, nil + } + } + } + } + return -1, fmt.Errorf("port not found") +} diff --git a/pkg/controller.v1/jax/jaxjob_controller.go b/pkg/controller.v1/jax/jaxjob_controller.go new file mode 100644 index 0000000000..9a6566954c --- /dev/null +++ b/pkg/controller.v1/jax/jaxjob_controller.go @@ -0,0 +1,488 @@ +// Copyright 2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jax + +import ( + "context" + "fmt" + "strings" + "time" + + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" + "github.com/kubeflow/training-operator/pkg/common/util" + "github.com/kubeflow/training-operator/pkg/controller.v1/common" + "github.com/kubeflow/training-operator/pkg/controller.v1/control" + "github.com/kubeflow/training-operator/pkg/controller.v1/expectation" + commonutil "github.com/kubeflow/training-operator/pkg/util" + + "github.com/go-logr/logr" + "github.com/sirupsen/logrus" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/equality" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/informers" + kubeclientset "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/record" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller" + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/manager" + "sigs.k8s.io/controller-runtime/pkg/predicate" + "sigs.k8s.io/controller-runtime/pkg/source" + schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" +) + +const ( + controllerName = "jaxjob-controller" +) + +// NewReconciler creates a JAXJob Reconciler +func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *JAXJobReconciler { + r := &JAXJobReconciler{ + client: mgr.GetClient(), + scheme: mgr.GetScheme(), + recorder: mgr.GetEventRecorderFor(controllerName), + apiReader: mgr.GetAPIReader(), + log: ctrl.Log.WithName(controllerName), + } + + // Create clients + cfg := mgr.GetConfig() + kubeClientSet := kubeclientset.NewForConfigOrDie(cfg) + sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0) + priorityClassInformer := sharedInformers.Scheduling().V1().PriorityClasses() + + // Initialize common job controller + r.JobController = common.JobController{ + Controller: r, + Expectations: expectation.NewControllerExpectations(), + WorkQueue: &util.FakeWorkQueue{}, + Recorder: r.recorder, + KubeClientSet: kubeClientSet, + PriorityClassLister: priorityClassInformer.Lister(), + PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced, + PodControl: control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder}, + ServiceControl: control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder}, + } + + gangSchedulingSetupFunc(&r.JobController) + + return r +} + +// JAXJobReconciler reconciles a JAXJob object +type JAXJobReconciler struct { + common.JobController + client client.Client + scheme *runtime.Scheme + log logr.Logger + recorder record.EventRecorder + apiReader client.Reader +} + +//+kubebuilder:rbac:groups=kubeflow.org,resources=jaxjobs,verbs=get;list;watch;create;update;patch;delete +//+kubebuilder:rbac:groups=kubeflow.org,resources=jaxjobs/status,verbs=get;update;patch +//+kubebuilder:rbac:groups=kubeflow.org,resources=jaxjobs/finalizers,verbs=update +//+kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;create;update;patch;delete +//+kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch;create;delete +//+kubebuilder:rbac:groups=scheduling.volcano.sh,resources=podgroups,verbs=get;list;watch;create;update;patch;delete +//+kubebuilder:rbac:groups=scheduling.x-k8s.io,resources=podgroups,verbs=get;list;watch;create;update;patch;delete +//+kubebuilder:rbac:groups="",resources=events,verbs=get;list;watch;create;update;patch;delete + +// Reconcile is part of the main kubernetes reconciliation loop which aims to +// move the current state of the cluster closer to the desired state. +// the JAXJob object against the actual cluster state, and then +// perform operations to make the cluster state reflect the state specified by +// the user. +// +// For more details, check Reconcile and its Result here: +// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.8.3/pkg/reconcile +func (r *JAXJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + + jaxjob := &kubeflowv1.JAXJob{} + err := r.client.Get(ctx, req.NamespacedName, jaxjob) + if err != nil { + return ctrl.Result{}, client.IgnoreNotFound(err) + } + + // log := ctrl.LoggerFrom(ctx).WithValues("jaxjob", klog.KObj(&jaxjob)) + // ctrl.LoggerInto(ctx, log) + // log.V(2).Info("Reconciling JAXJob") + + // Check if reconciliation is needed + jobKey, err := common.KeyFunc(jaxjob) + if err != nil { + utilruntime.HandleError(fmt.Errorf("couldn't get jobKey for job object %#v: %v", jaxjob, err)) + } + + replicaTypes := util.GetReplicaTypes(jaxjob.Spec.JAXReplicaSpecs) + needReconcile := util.SatisfiedExpectations(r.Expectations, jobKey, replicaTypes) + + if !needReconcile || jaxjob.GetDeletionTimestamp() != nil { + r.log.Info("reconcile cancelled, job does not need to do reconcile or has been deleted", + "sync", needReconcile, "deleted", jaxjob.GetDeletionTimestamp() != nil) + return ctrl.Result{}, nil + } + + // Set default priorities to jax job + r.scheme.Default(jaxjob) + + // Use common to reconcile the job related pod and service + err = r.ReconcileJobs(jaxjob, jaxjob.Spec.JAXReplicaSpecs, jaxjob.Status, &jaxjob.Spec.RunPolicy) + if err != nil { + r.log.Error(err, "Reconcile JAXJob error") + return ctrl.Result{}, err + } + t, err := util.DurationUntilExpireTime(&jaxjob.Spec.RunPolicy, jaxjob.Status) + if err != nil { + logrus.Warnf("Reconcile JAXJob error %v", err) + return ctrl.Result{}, err + } + if t >= 0 { + return ctrl.Result{Requeue: true, RequeueAfter: t}, nil + } + + return ctrl.Result{}, nil +} + +// SetupWithManager sets up the controller with the Manager. +func (r *JAXJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads int) error { + c, err := controller.New(r.ControllerName(), mgr, controller.Options{ + Reconciler: r, + MaxConcurrentReconciles: controllerThreads, + }) + if err != nil { + return err + } + + // using onOwnerCreateFunc is easier to set defaults + if err = c.Watch(source.Kind(mgr.GetCache(), &kubeflowv1.JAXJob{}), &handler.EnqueueRequestForObject{}, + predicate.Funcs{CreateFunc: r.onOwnerCreateFunc()}, + ); err != nil { + return err + } + + // eventHandler for owned object + eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.JAXJob{}, handler.OnlyControllerOwner()) + predicates := predicate.Funcs{ + CreateFunc: util.OnDependentCreateFunc(r.Expectations), + UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), + DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), + } + // Create generic predicates + genericPredicates := predicate.Funcs{ + CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations), + UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController), + DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations), + } + // inject watching for job related pod + if err = c.Watch(source.Kind(mgr.GetCache(), &corev1.Pod{}), eventHandler, predicates); err != nil { + return err + } + // inject watching for job related service + if err = c.Watch(source.Kind(mgr.GetCache(), &corev1.Service{}), eventHandler, predicates); err != nil { + return err + } + // skip watching volcano PodGroup if volcano PodGroup is not installed + if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.GroupName, Kind: "PodGroup"}, + v1beta1.SchemeGroupVersion.Version); err == nil { + // inject watching for job related volcano PodGroup + if err = c.Watch(source.Kind(mgr.GetCache(), &v1beta1.PodGroup{}), eventHandler, genericPredicates); err != nil { + return err + } + } + // skip watching scheduler-plugins PodGroup if scheduler-plugins PodGroup is not installed + if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: schedulerpluginsv1alpha1.SchemeGroupVersion.Group, Kind: "PodGroup"}, + schedulerpluginsv1alpha1.SchemeGroupVersion.Version); err == nil { + // inject watching for job related scheduler-plugins PodGroup + if err = c.Watch(source.Kind(mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}), eventHandler, genericPredicates); err != nil { + return err + } + } + return nil +} + +func (r *JAXJobReconciler) ControllerName() string { + return controllerName +} + +func (r *JAXJobReconciler) GetAPIGroupVersionKind() schema.GroupVersionKind { + return kubeflowv1.GroupVersion.WithKind(kubeflowv1.JAXJobKind) +} + +func (r *JAXJobReconciler) GetAPIGroupVersion() schema.GroupVersion { + return kubeflowv1.GroupVersion +} + +func (r *JAXJobReconciler) GetGroupNameLabelValue() string { + return kubeflowv1.GroupVersion.Group +} + +func (r *JAXJobReconciler) GetFrameworkName() string { + return kubeflowv1.JAXJobFrameworkName +} + +func (r *JAXJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) { + job := &kubeflowv1.JAXJob{} + err := r.client.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job) + if err != nil { + if errors.IsNotFound(err) { + logrus.Error(err, "jax job not found", "namespace", namespace, "name", name) + } else { + logrus.Error(err, "failed to get job from api-server", "namespace", namespace, "name", name) + } + return nil, err + } + return job, nil +} + +func (r *JAXJobReconciler) GetJobFromAPIClient(namespace, name string) (metav1.Object, error) { + job := &kubeflowv1.JAXJob{} + + err := r.apiReader.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job) + if err != nil { + if errors.IsNotFound(err) { + logrus.Error(err, "jax job not found", "namespace", namespace, "name", name) + } else { + logrus.Error(err, "failed to get job from api-server", "namespace", namespace, "name", name) + } + return nil, err + } + return job, nil +} + +func (r *JAXJobReconciler) GetPodsForJob(obj interface{}) ([]*corev1.Pod, error) { + job, err := meta.Accessor(obj) + if err != nil { + return nil, err + } + + // List all pods to include those that don't match the selector anymore + // but have a ControllerRef pointing to this controller. + podlist := &corev1.PodList{} + err = r.client.List(context.Background(), podlist, client.MatchingLabels(r.GenLabels(job.GetName())), client.InNamespace(job.GetNamespace())) + if err != nil { + return nil, err + } + + return util.JobControlledPodList(podlist.Items, job), nil +} + +func (r *JAXJobReconciler) GetServicesForJob(obj interface{}) ([]*corev1.Service, error) { + job, err := meta.Accessor(obj) + if err != nil { + return nil, err + } + + // List all pods to include those that don't match the selector anymore + // but have a ControllerRef pointing to this controller. + serviceList := &corev1.ServiceList{} + err = r.client.List(context.Background(), serviceList, client.MatchingLabels(r.GenLabels(job.GetName())), client.InNamespace(job.GetNamespace())) + if err != nil { + return nil, err + } + + ret := util.ConvertServiceList(serviceList.Items) + return ret, nil +} + +func (r *JAXJobReconciler) DeleteJob(job interface{}) error { + jaxjob, ok := job.(*kubeflowv1.JAXJob) + if !ok { + return fmt.Errorf("%+v is not a type of JAXJob", job) + } + if err := r.client.Delete(context.Background(), jaxjob); err != nil { + r.recorder.Eventf(jaxjob, corev1.EventTypeWarning, control.FailedDeletePodReason, "Error deleting: %v", err) + logrus.Error(err, "failed to delete job", "namespace", jaxjob.Namespace, "name", jaxjob.Name) + return err + } + r.recorder.Eventf(jaxjob, corev1.EventTypeNormal, control.SuccessfulDeletePodReason, "Deleted job: %v", jaxjob.Name) + logrus.Info("job deleted", "namespace", jaxjob.Namespace, "name", jaxjob.Name) + trainingoperatorcommon.DeletedJobsCounterInc(jaxjob.Namespace, r.GetFrameworkName()) + return nil +} + +func (r *JAXJobReconciler) GenLabelSelector(jobName string, + rtype kubeflowv1.ReplicaType) *metav1.LabelSelector { + labels := r.GenLabels(jobName) + labels[kubeflowv1.ReplicaTypeLabel] = strings.ToLower(string(rtype)) + + return &metav1.LabelSelector{ + MatchLabels: labels, + } +} + +// UpdateJobStatus updates the job status and job conditions +func (r *JAXJobReconciler) UpdateJobStatus(job interface{}, + replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec, + jobStatus *kubeflowv1.JobStatus) error { + jaxjob, ok := job.(*kubeflowv1.JAXJob) + if !ok { + return fmt.Errorf("%+v is not a type of JAXJob", job) + } + jaxjobKey, err := common.KeyFunc(jaxjob) + if err != nil { + utilruntime.HandleError(fmt.Errorf("couldn't get key for jaxjob object %#v: %v", jaxjob, err)) + return err + } + + logger := commonutil.LoggerForJob(jaxjob) + + // Set StartTime. + if jobStatus.StartTime == nil { + now := metav1.Now() + jobStatus.StartTime = &now + // enqueue a sync to check if job past ActiveDeadlineSeconds + if jaxjob.Spec.RunPolicy.ActiveDeadlineSeconds != nil { + logger.Infof("Job with ActiveDeadlineSeconds will sync after %d seconds", *jaxjob.Spec.RunPolicy.ActiveDeadlineSeconds) + r.WorkQueue.AddAfter(jaxjobKey, time.Duration(*jaxjob.Spec.RunPolicy.ActiveDeadlineSeconds)*time.Second) + } + } + + for rtype, spec := range replicas { + status := jobStatus.ReplicaStatuses[rtype] + // Generate the label selector. + status.Selector = metav1.FormatLabelSelector(r.GenLabelSelector(jaxjob.Name, rtype)) + + succeeded := status.Succeeded + expected := *(spec.Replicas) - succeeded + running := status.Active + failed := status.Failed + specReplicas := *spec.Replicas + + logrus.Infof("JAXJob=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d, failed=%d, Replicas=%d", + jaxjob.Name, rtype, expected, running, succeeded, failed, specReplicas) + + if rtype == kubeflowv1.JAXJobReplicaTypeWorker { + if expected == 0 { + msg := fmt.Sprintf("JAXJob %s/%s successfully completed.", + jaxjob.Namespace, jaxjob.Name) + r.recorder.Event(jaxjob, corev1.EventTypeNormal, commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobSucceededReason), msg) + if jobStatus.CompletionTime == nil { + now := metav1.Now() + jobStatus.CompletionTime = &now + } + commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobSucceeded, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobSucceededReason), msg) + trainingoperatorcommon.SuccessfulJobsCounterInc(jaxjob.Namespace, r.GetFrameworkName()) + } else if running > 0 { + // Some workers are still running, leave a running condition. + msg := fmt.Sprintf("JAXJob %s/%s is running.", + jaxjob.Namespace, jaxjob.Name) + commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRunning, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRunningReason), msg) + } + } + + if failed > 0 && (specReplicas > succeeded+running) { + if spec.RestartPolicy != kubeflowv1.RestartPolicyNever { + msg := fmt.Sprintf("JAXJob %s is restarting because %d %s replica(s) failed.", jaxjob.Name, failed, rtype) + r.Recorder.Event(jaxjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRestartingReason), msg) + commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRestarting, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRestartingReason), msg) + trainingoperatorcommon.RestartedJobsCounterInc(jaxjob.Namespace, r.GetFrameworkName()) + } else { + msg := fmt.Sprintf("JAXJob %s is failed because %d %s replica(s) failed.", jaxjob.Name, failed, rtype) + r.Recorder.Event(jaxjob, corev1.EventTypeNormal, commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobFailedReason), msg) + if jobStatus.CompletionTime == nil { + now := metav1.Now() + jobStatus.CompletionTime = &now + } + commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobFailed, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobFailedReason), msg) + trainingoperatorcommon.FailedJobsCounterInc(jaxjob.Namespace, r.GetFrameworkName()) + } + } + } + return nil +} + +// UpdateJobStatusInApiServer updates the job status in to cluster. +func (r *JAXJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus *kubeflowv1.JobStatus) error { + if jobStatus.ReplicaStatuses == nil { + jobStatus.ReplicaStatuses = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaStatus{} + } + + jaxjob, ok := job.(*kubeflowv1.JAXJob) + trainingoperatorcommon.ClearGeneratedFields(&jaxjob.ObjectMeta) + if !ok { + return fmt.Errorf("%+v is not a type of JAXJob", job) + } + + // Job status passed in differs with status in job, update in basis of the passed in one. + if !equality.Semantic.DeepEqual(&jaxjob.Status, jobStatus) { + jaxjob = jaxjob.DeepCopy() + jaxjob.Status = *jobStatus.DeepCopy() + } + + result := r.client.Status().Update(context.Background(), jaxjob) + + if result != nil { + r.log.WithValues("jaxjob", types.NamespacedName{ + Namespace: jaxjob.GetNamespace(), + Name: jaxjob.GetName(), + }) + return result + } + + return nil +} + +// SetClusterSpec sets the cluster spec and init container for the pod +func (r *JAXJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { + jaxjob, ok := job.(*kubeflowv1.JAXJob) + if !ok { + return fmt.Errorf("%+v is not a type of JAXJob", job) + } + if err := setPodEnv(jaxjob, podTemplate, rtype, index); err != nil { + return err + } + return nil +} + +func (r *JAXJobReconciler) GetDefaultContainerName() string { + return kubeflowv1.JAXJobDefaultContainerName +} + +func (r *JAXJobReconciler) GetDefaultContainerPortName() string { + return kubeflowv1.JAXJobDefaultPortName +} + +func (r *JAXJobReconciler) IsMasterRole(replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec, + rtype kubeflowv1.ReplicaType, index int) bool { + return index == 0 +} + +// onOwnerCreateFunc modify creation condition. +func (r *JAXJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool { + return func(e event.CreateEvent) bool { + jaxjob, ok := e.Object.(*kubeflowv1.JAXJob) + if !ok { + return true + } + r.scheme.Default(jaxjob) + msg := fmt.Sprintf("JAXJob %s is created.", e.Object.GetName()) + logrus.Info(msg) + trainingoperatorcommon.CreatedJobsCounterInc(jaxjob.Namespace, r.GetFrameworkName()) + commonutil.UpdateJobConditions(&jaxjob.Status, kubeflowv1.JobCreated, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason), msg) + return true + } +} diff --git a/pkg/controller.v1/jax/jaxjob_controller_suite_test.go b/pkg/controller.v1/jax/jaxjob_controller_suite_test.go new file mode 100644 index 0000000000..a9471d9c83 --- /dev/null +++ b/pkg/controller.v1/jax/jaxjob_controller_suite_test.go @@ -0,0 +1,125 @@ +// Copyright 2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jax + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "path/filepath" + "testing" + "time" + + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/controller.v1/common" + jaxwebhook "github.com/kubeflow/training-operator/pkg/webhooks/jax" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/envtest" + logf "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/log/zap" + metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + "sigs.k8s.io/controller-runtime/pkg/webhook" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" + //+kubebuilder:scaffold:imports +) + +var ( + testK8sClient client.Client + testEnv *envtest.Environment + testCtx context.Context + testCancel context.CancelFunc +) + +func TestAPIs(t *testing.T) { + RegisterFailHandler(Fail) + + RunSpecs(t, "Controller Suite") +} + +var _ = BeforeSuite(func() { + logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + + testCtx, testCancel = context.WithCancel(context.TODO()) + + By("bootstrapping test environment") + testEnv = &envtest.Environment{ + CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")}, + ErrorIfCRDPathMissing: true, + WebhookInstallOptions: envtest.WebhookInstallOptions{ + Paths: []string{filepath.Join("..", "..", "..", "manifests", "base", "webhook", "manifests.yaml")}, + }, + } + + cfg, err := testEnv.Start() + Expect(err).NotTo(HaveOccurred()) + Expect(cfg).NotTo(BeNil()) + + err = v1beta1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) + err = kubeflowv1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) + + //+kubebuilder:scaffold:scheme + + testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme}) + Expect(err).NotTo(HaveOccurred()) + Expect(testK8sClient).NotTo(BeNil()) + + mgr, err := ctrl.NewManager(cfg, ctrl.Options{ + Metrics: metricsserver.Options{ + BindAddress: "0", + }, + WebhookServer: webhook.NewServer( + webhook.Options{ + Host: testEnv.WebhookInstallOptions.LocalServingHost, + Port: testEnv.WebhookInstallOptions.LocalServingPort, + CertDir: testEnv.WebhookInstallOptions.LocalServingCertDir, + }), + }) + Expect(err).NotTo(HaveOccurred()) + + gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() + r := NewReconciler(mgr, gangSchedulingSetupFunc) + + Expect(r.SetupWithManager(mgr, 1)).NotTo(HaveOccurred()) + Expect(jaxwebhook.SetupWebhook(mgr)).NotTo(HaveOccurred()) + + go func() { + defer GinkgoRecover() + err = mgr.Start(testCtx) + Expect(err).ToNot(HaveOccurred(), "failed to run manager") + }() + + dialer := &net.Dialer{Timeout: time.Second} + addrPort := fmt.Sprintf("%s:%d", testEnv.WebhookInstallOptions.LocalServingHost, testEnv.WebhookInstallOptions.LocalServingPort) + Eventually(func(g Gomega) { + conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true}) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(conn.Close()).NotTo(HaveOccurred()) + }).Should(Succeed()) +}) + +var _ = AfterSuite(func() { + By("tearing down the test environment") + testCancel() + err := testEnv.Stop() + Expect(err).NotTo(HaveOccurred()) +}) diff --git a/pkg/controller.v1/jax/jaxjob_controller_test.go b/pkg/controller.v1/jax/jaxjob_controller_test.go new file mode 100644 index 0000000000..7a6255aef0 --- /dev/null +++ b/pkg/controller.v1/jax/jaxjob_controller_test.go @@ -0,0 +1,316 @@ +// Copyright 2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jax + +import ( + "context" + "fmt" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + commonutil "github.com/kubeflow/training-operator/pkg/util" + "github.com/kubeflow/training-operator/pkg/util/testutil" +) + +var _ = Describe("JAXJob controller", func() { + // Define utility constants for object names. + const ( + expectedPort = int32(6666) + ) + + Context("When creating the JAXJob", func() { + const name = "test-job" + var ( + ns *corev1.Namespace + job *kubeflowv1.JAXJob + jobKey types.NamespacedName + worker0Key types.NamespacedName + ctx = context.Background() + ) + BeforeEach(func() { + ns = &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "jax-test-", + }, + } + Expect(testK8sClient.Create(ctx, ns)).Should(Succeed()) + + job = &kubeflowv1.JAXJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: ns.Name, + }, + } + jobKey = client.ObjectKeyFromObject(job) + + worker0Key = types.NamespacedName{ + Name: fmt.Sprintf("%s-worker-0", name), + Namespace: ns.Name, + } + job.Spec.JAXReplicaSpecs = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{ + kubeflowv1.JAXJobReplicaTypeWorker: { + Replicas: ptr.To[int32](2), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Image: "test-image", + Name: kubeflowv1.JAXJobDefaultContainerName, + Ports: []corev1.ContainerPort{ + { + Name: kubeflowv1.JAXJobDefaultPortName, + ContainerPort: expectedPort, + Protocol: corev1.ProtocolTCP, + }, + }, + }, + }, + }, + }, + }, + } + }) + AfterEach(func() { + Expect(testK8sClient.Delete(ctx, job)).Should(Succeed()) + Expect(testK8sClient.Delete(ctx, ns)).Should(Succeed()) + }) + + It("Shouldn't create resources if JAXJob is suspended", func() { + By("By creating a new JAXJob with suspend=true") + job.Spec.RunPolicy.Suspend = ptr.To(true) + job.Spec.JAXReplicaSpecs[kubeflowv1.JAXJobReplicaTypeWorker].Replicas = ptr.To[int32](1) + Expect(testK8sClient.Create(ctx, job)).Should(Succeed()) + + created := &kubeflowv1.JAXJob{} + workerPod := &corev1.Pod{} + workerSvc := &corev1.Service{} + + By("Checking created JAXJob") + Eventually(func() bool { + err := testK8sClient.Get(ctx, jobKey, created) + return err == nil + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + By("Checking created JAXJob has a nil startTime") + Consistently(func() *metav1.Time { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.StartTime + }, testutil.ConsistentDuration, testutil.Interval).Should(BeNil()) + + By("Checking if the pods and services aren't created") + Consistently(func() bool { + errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod) + errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc) + return errors.IsNotFound(errWorkerPod) && + errors.IsNotFound(errWorkerSvc) + }, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue()) + + By("Checking if the JAXJob has suspended condition") + Eventually(func() []kubeflowv1.JobCondition { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.Conditions + }, testutil.ConsistentDuration, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{ + { + Type: kubeflowv1.JobCreated, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason), + Message: fmt.Sprintf("JAXJob %s is created.", name), + }, + { + Type: kubeflowv1.JobSuspended, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobSuspendedReason), + Message: fmt.Sprintf("JAXJob %s is suspended.", name), + }, + }, testutil.IgnoreJobConditionsTimes)) + }) + + It("Should delete resources after JAXJob is suspended; Should resume JAXJob after JAXJob is unsuspended", func() { + By("By creating a new JAXJob") + job.Spec.JAXReplicaSpecs[kubeflowv1.JAXJobReplicaTypeWorker].Replicas = ptr.To[int32](1) + Expect(testK8sClient.Create(ctx, job)).Should(Succeed()) + + created := &kubeflowv1.JAXJob{} + workerPod := &corev1.Pod{} + workerSvc := &corev1.Service{} + + // We'll need to retry getting this newly created JAXJob, given that creation may not immediately happen. + By("Checking created JAXJob") + Eventually(func() bool { + err := testK8sClient.Get(ctx, jobKey, created) + return err == nil + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + + var startTimeBeforeSuspended *metav1.Time + Eventually(func() *metav1.Time { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + startTimeBeforeSuspended = created.Status.StartTime + return startTimeBeforeSuspended + }, testutil.Timeout, testutil.Interval).ShouldNot(BeNil()) + + By("Checking the created pods and services") + Eventually(func() bool { + errWorker := testK8sClient.Get(ctx, worker0Key, workerPod) + return errWorker == nil + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + Eventually(func() bool { + errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc) + return errWorker == nil + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + + By("Updating the pod's phase with Running") + Eventually(func() error { + Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed()) + workerPod.Status.Phase = corev1.PodRunning + return testK8sClient.Status().Update(ctx, workerPod) + }, testutil.Timeout, testutil.Interval).Should(Succeed()) + + By("Checking the JAXJob's condition") + Eventually(func() []kubeflowv1.JobCondition { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.Conditions + }, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{ + { + Type: kubeflowv1.JobCreated, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason), + Message: fmt.Sprintf("JAXJob %s is created.", name), + }, + { + Type: kubeflowv1.JobRunning, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRunningReason), + Message: fmt.Sprintf("JAXJob %s/%s is running.", ns.Name, name), + }, + }, testutil.IgnoreJobConditionsTimes)) + + By("Updating the JAXJob with suspend=true") + Eventually(func() error { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + created.Spec.RunPolicy.Suspend = ptr.To(true) + return testK8sClient.Update(ctx, created) + }, testutil.Timeout, testutil.Interval).Should(Succeed()) + + By("Checking if the pods and services are removed") + Eventually(func() bool { + errWorker := testK8sClient.Get(ctx, worker0Key, workerPod) + return errors.IsNotFound(errWorker) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + Eventually(func() bool { + errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc) + return errors.IsNotFound(errWorker) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + Consistently(func() bool { + errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod) + errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc) + return errors.IsNotFound(errWorkerPod) && + errors.IsNotFound(errWorkerSvc) + }, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue()) + + By("Checking if the JAXJob has a suspended condition") + Eventually(func() bool { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.ReplicaStatuses[kubeflowv1.JAXJobReplicaTypeWorker].Active == 0 && + created.Status.StartTime.Equal(startTimeBeforeSuspended) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + Consistently(func() bool { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.ReplicaStatuses[kubeflowv1.JAXJobReplicaTypeWorker].Active == 0 && + created.Status.StartTime.Equal(startTimeBeforeSuspended) + }, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue()) + Expect(created.Status.Conditions).Should(BeComparableTo([]kubeflowv1.JobCondition{ + { + Type: kubeflowv1.JobCreated, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason), + Message: fmt.Sprintf("JAXJob %s is created.", name), + }, + { + Type: kubeflowv1.JobRunning, + Status: corev1.ConditionFalse, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobSuspendedReason), + Message: fmt.Sprintf("JAXJob %s is suspended.", name), + }, + { + Type: kubeflowv1.JobSuspended, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobSuspendedReason), + Message: fmt.Sprintf("JAXJob %s is suspended.", name), + Status: corev1.ConditionTrue, + }, + }, testutil.IgnoreJobConditionsTimes)) + + By("Unsuspending the JAXJob") + Eventually(func() error { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + created.Spec.RunPolicy.Suspend = ptr.To(false) + return testK8sClient.Update(ctx, created) + }, testutil.Timeout, testutil.Interval).Should(Succeed()) + Eventually(func() *metav1.Time { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.StartTime + }, testutil.Timeout, testutil.Interval).ShouldNot(BeNil()) + + By("Check if the pods and services are created") + Eventually(func() error { + return testK8sClient.Get(ctx, worker0Key, workerPod) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) + Eventually(func() error { + return testK8sClient.Get(ctx, worker0Key, workerSvc) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) + + By("Updating Pod's condition with running") + Eventually(func() error { + Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed()) + workerPod.Status.Phase = corev1.PodRunning + return testK8sClient.Status().Update(ctx, workerPod) + }, testutil.Timeout, testutil.Interval).Should(Succeed()) + + By("Checking if the JAXJob has resumed conditions") + Eventually(func() []kubeflowv1.JobCondition { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.Conditions + }, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{ + { + Type: kubeflowv1.JobCreated, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason), + Message: fmt.Sprintf("JAXJob %s is created.", name), + }, + { + Type: kubeflowv1.JobSuspended, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobResumedReason), + Message: fmt.Sprintf("JAXJob %s is resumed.", name), + Status: corev1.ConditionFalse, + }, + { + Type: kubeflowv1.JobRunning, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRunningReason), + Message: fmt.Sprintf("JAXJob %s/%s is running.", ns.Name, name), + }, + }, testutil.IgnoreJobConditionsTimes)) + + By("Checking if the startTime is updated") + Expect(created.Status.StartTime).ShouldNot(Equal(startTimeBeforeSuspended)) + }) + }) +}) diff --git a/pkg/controller.v1/register_controller.go b/pkg/controller.v1/register_controller.go index 114271cc8e..ea099ced14 100644 --- a/pkg/controller.v1/register_controller.go +++ b/pkg/controller.v1/register_controller.go @@ -20,6 +20,7 @@ import ( kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/controller.v1/common" + jaxcontroller "github.com/kubeflow/training-operator/pkg/controller.v1/jax" mpicontroller "github.com/kubeflow/training-operator/pkg/controller.v1/mpi" paddlecontroller "github.com/kubeflow/training-operator/pkg/controller.v1/paddlepaddle" pytorchcontroller "github.com/kubeflow/training-operator/pkg/controller.v1/pytorch" @@ -49,6 +50,9 @@ var SupportedSchemeReconciler = map[string]ReconcilerSetupFunc{ kubeflowv1.PaddleJobKind: func(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc, controllerThreads int) error { return paddlecontroller.NewReconciler(mgr, gangSchedulingSetupFunc).SetupWithManager(mgr, controllerThreads) }, + kubeflowv1.JAXJobKind: func(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc, controllerThreads int) error { + return jaxcontroller.NewReconciler(mgr, gangSchedulingSetupFunc).SetupWithManager(mgr, controllerThreads) + }, } type EnabledSchemes []string diff --git a/pkg/webhooks/jax/jaxjob_webhook.go b/pkg/webhooks/jax/jaxjob_webhook.go new file mode 100644 index 0000000000..12888b3d3c --- /dev/null +++ b/pkg/webhooks/jax/jaxjob_webhook.go @@ -0,0 +1,124 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package jax + +import ( + "context" + "fmt" + "slices" + "strings" + + apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/webhook" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +var ( + specPath = field.NewPath("spec") + jaxReplicaSpecPath = specPath.Child("jaxReplicaSpecs") +) + +type Webhook struct{} + +func SetupWebhook(mgr ctrl.Manager) error { + return ctrl.NewWebhookManagedBy(mgr). + For(&trainingoperator.JAXJob{}). + WithValidator(&Webhook{}). + Complete() +} + +// +kubebuilder:webhook:path=/validate-kubeflow-org-v1-jaxjob,mutating=false,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=jaxjobs,verbs=create;update,versions=v1,name=validator.jaxjob.training-operator.kubeflow.org,admissionReviewVersions=v1 + +var _ webhook.CustomValidator = &Webhook{} + +func (w *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) { + job := obj.(*trainingoperator.JAXJob) + log := ctrl.LoggerFrom(ctx).WithName("jaxjob-webhook") + log.V(5).Info("Validating create", "jaxJob", klog.KObj(job)) + return nil, validateJAXJob(job).ToAggregate() +} + +func (w *Webhook) ValidateUpdate(ctx context.Context, _ runtime.Object, newObj runtime.Object) (admission.Warnings, error) { + job := newObj.(*trainingoperator.JAXJob) + log := ctrl.LoggerFrom(ctx).WithName("jaxjob-webhook") + log.V(5).Info("Validating update", "jaxJob", klog.KObj(job)) + return nil, validateJAXJob(job).ToAggregate() +} + +func (w *Webhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) { + return nil, nil +} + +func validateJAXJob(job *trainingoperator.JAXJob) field.ErrorList { + var allErrs field.ErrorList + if errors := apimachineryvalidation.NameIsDNS1035Label(job.ObjectMeta.Name, false); len(errors) != 0 { + allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), job.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ",")))) + } + + allErrs = append(allErrs, validateSpec(job.Spec)...) + return allErrs +} + +func validateSpec(spec trainingoperator.JAXJobSpec) field.ErrorList { + return validateJAXReplicaSpecs(spec.JAXReplicaSpecs) +} + +func validateJAXReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { + var allErrs field.ErrorList + + if rSpecs == nil { + allErrs = append(allErrs, field.Required(jaxReplicaSpecPath, "must be required")) + } + for rType, rSpec := range rSpecs { + rolePath := jaxReplicaSpecPath.Key(string(rType)) + containersPath := rolePath.Child("template").Child("spec").Child("containers") + + // Make sure the replica type is valid. + validRoleTypes := []trainingoperator.ReplicaType{ + trainingoperator.JAXJobReplicaTypeWorker, + } + if !slices.Contains(validRoleTypes, rType) { + allErrs = append(allErrs, field.NotSupported(rolePath, rType, validRoleTypes)) + } + + if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { + allErrs = append(allErrs, field.Required(containersPath, "must be specified")) + } + + // Make sure the image is defined in the container + defaultContainerPresent := false + for idx, container := range rSpec.Template.Spec.Containers { + if container.Image == "" { + allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) + } + if container.Name == trainingoperator.JAXJobDefaultContainerName { + defaultContainerPresent = true + } + } + // Make sure there has at least one container named "jax" + if !defaultContainerPresent { + allErrs = append(allErrs, field.Required(containersPath, fmt.Sprintf("must have at least one container with name %s", trainingoperator.JAXJobDefaultContainerName))) + } + } + return allErrs +} diff --git a/pkg/webhooks/jax/jaxjob_webhook_test.go b/pkg/webhooks/jax/jaxjob_webhook_test.go new file mode 100644 index 0000000000..ed3bd2bd1b --- /dev/null +++ b/pkg/webhooks/jax/jaxjob_webhook_test.go @@ -0,0 +1,198 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package jax + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/utils/ptr" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +func TestValidateV1JAXJob(t *testing.T) { + validJAXReplicaSpecs := map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.JAXJobReplicaTypeWorker: { + Replicas: ptr.To[int32](1), + RestartPolicy: trainingoperator.RestartPolicyOnFailure, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "jax", + Image: "docker.io/sandipanify/jaxgoogle:latest", + Ports: []corev1.ContainerPort{{ + Name: "jaxjob-port", + ContainerPort: 6666, + }}, + ImagePullPolicy: corev1.PullAlways, + Command: []string{ + "python", + "train.py", + }, + }}, + }, + }, + }, + } + + testCases := map[string]struct { + jaxJob *trainingoperator.JAXJob + wantErr field.ErrorList + }{ + "valid JAXJob": { + jaxJob: &trainingoperator.JAXJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.JAXJobSpec{ + JAXReplicaSpecs: validJAXReplicaSpecs, + }, + }, + }, + "jaxJob name does not meet DNS1035": { + jaxJob: &trainingoperator.JAXJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "0-test", + }, + Spec: trainingoperator.JAXJobSpec{ + JAXReplicaSpecs: validJAXReplicaSpecs, + }, + }, + wantErr: field.ErrorList{ + field.Invalid(field.NewPath("metadata").Child("name"), "", ""), + }, + }, + "no containers": { + jaxJob: &trainingoperator.JAXJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.JAXJobSpec{ + JAXReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.JAXJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(jaxReplicaSpecPath. + Key(string(trainingoperator.JAXJobReplicaTypeWorker)). + Child("template"). + Child("spec"). + Child("containers"), ""), + field.Required(jaxReplicaSpecPath. + Key(string(trainingoperator.JAXJobReplicaTypeWorker)). + Child("template"). + Child("spec"). + Child("containers"), ""), + }, + }, + "image is empty": { + jaxJob: &trainingoperator.JAXJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.JAXJobSpec{ + JAXReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.JAXJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "jax", + Image: "", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(jaxReplicaSpecPath. + Key(string(trainingoperator.JAXJobReplicaTypeWorker)). + Child("template"). + Child("spec"). + Child("containers"). + Index(0). + Child("image"), ""), + }, + }, + "jaxJob default container name doesn't present": { + jaxJob: &trainingoperator.JAXJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.JAXJobSpec{ + JAXReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.JAXJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "", + Image: "gcr.io/kubeflow-ci/jaxjob-simple_test:1.0", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(jaxReplicaSpecPath. + Key(string(trainingoperator.JAXJobReplicaTypeWorker)). + Child("template"). + Child("spec"). + Child("containers"), ""), + }, + }, + "replicaSpec is nil": { + jaxJob: &trainingoperator.JAXJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.JAXJobSpec{ + JAXReplicaSpecs: nil, + }, + }, + wantErr: field.ErrorList{ + field.Required(jaxReplicaSpecPath, ""), + }, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := validateJAXJob(tc.jaxJob) + if diff := cmp.Diff(tc.wantErr, got, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/pkg/webhooks/webhooks.go b/pkg/webhooks/webhooks.go index 29ad08e2fd..d1dd2b2f8e 100644 --- a/pkg/webhooks/webhooks.go +++ b/pkg/webhooks/webhooks.go @@ -20,6 +20,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/webhooks/jax" "github.com/kubeflow/training-operator/pkg/webhooks/paddlepaddle" "github.com/kubeflow/training-operator/pkg/webhooks/pytorch" "github.com/kubeflow/training-operator/pkg/webhooks/tensorflow" @@ -35,6 +36,7 @@ var ( trainingoperator.XGBoostJobKind: xgboost.SetupWebhook, trainingoperator.MPIJobKind: scaffold, trainingoperator.PaddleJobKind: paddlepaddle.SetupWebhook, + trainingoperator.JAXJobKind: jax.SetupWebhook, } ) diff --git a/sdk/python/kubeflow/training/api/training_client.py b/sdk/python/kubeflow/training/api/training_client.py index acade226c9..1626f18820 100644 --- a/sdk/python/kubeflow/training/api/training_client.py +++ b/sdk/python/kubeflow/training/api/training_client.py @@ -971,6 +971,8 @@ def get_job_pods( For PaddleJob one of `master` or `worker`. + For JAXJob `worker`. + replica_index: Index for the Job replica. timeout: Kubernetes API server timeout in seconds to execute the request. @@ -992,6 +994,7 @@ def get_job_pods( and replica_type not in constants.XGBOOSTJOB_REPLICA_TYPES and replica_type not in constants.MPIJOB_REPLICA_TYPES and replica_type not in constants.PADDLEJOB_REPLICA_TYPES + and replica_type not in constants.JAXJOB_REPLICA_TYPES ): raise ValueError( f"TFJob replica type must be one of {constants.TFJOB_REPLICA_TYPES}\n" @@ -999,6 +1002,7 @@ def get_job_pods( f"XGBoostJob replica type must be one of {constants.XGBOOSTJOB_REPLICA_TYPES}\n" f"MPIJob replica type must be one of {constants.MPIJOB_REPLICA_TYPES}\n" f"PaddleJob replica type must be one of {constants.PADDLEJOB_REPLICA_TYPES}" + f"JAXJob replica type must be one of {constants.PADDLEJOB_REPLICA_TYPES}" ) label_selector = f"{constants.JOB_NAME_LABEL}={name}" @@ -1058,6 +1062,8 @@ def get_job_pod_names( For PaddleJob one of `master` or `worker`. + For JAXJob `worker`. + replica_index: Index for the Job replica. timeout: Kubernetes API server timeout in seconds to execute the request. @@ -1118,6 +1124,8 @@ def get_job_logs( For MPIJob one of `launcher` or `worker`. For PaddleJob one of `master` or `worker`. + + For JAXJob `worker`. replica_index: Optional, index for the Job replica. container: Pod container to get the logs. follow: Whether to follow the log stream of the pod and print logs to StdOut. diff --git a/sdk/python/kubeflow/training/constants/constants.py b/sdk/python/kubeflow/training/constants/constants.py index 7b5759cef0..0bb6fe495e 100644 --- a/sdk/python/kubeflow/training/constants/constants.py +++ b/sdk/python/kubeflow/training/constants/constants.py @@ -138,6 +138,13 @@ "docker.io/paddlepaddle/paddle:2.4.0rc0-gpu-cuda11.2-cudnn8.1-trt8.0" ) +# JAXJob constants +JAXJOB_KIND = "JAXJob" +JAXJOB_MODEL = "KubeflowOrgV1JAXJob" +JAXJOB_PLURAL = "jaxjobs" +JAXJOB_CONTAINER = "jax" +JAXJOB_REPLICA_TYPES = REPLICA_TYPE_WORKER.lower() +JAXJOB_BASE_IMAGE = "kubeflow/jaxjob-simple:latest" # Dictionary to get plural, model, and container for each Job kind. JOB_PARAMETERS = { @@ -171,6 +178,12 @@ "container": PADDLEJOB_CONTAINER, "base_image": PADDLEJOB_BASE_IMAGE, }, + JAXJOB_KIND: { + "model": JAXJOB_MODEL, + "plural": JAXJOB_PLURAL, + "container": JAXJOB_CONTAINER, + "base_image": "JAXJOB_BASE_IMAGE", + }, } # Tuple of all Job models. @@ -183,4 +196,5 @@ models.KubeflowOrgV1XGBoostJob, models.KubeflowOrgV1MPIJob, models.KubeflowOrgV1PaddleJob, + models.KubeflowOrgV1JAXJob, ] diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py new file mode 100644 index 0000000000..cf350f1c11 --- /dev/null +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -0,0 +1,161 @@ +# Copyright 2024 kubeflow.org. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import pytest +from typing import Optional + +from kubernetes.client import V1PodTemplateSpec +from kubernetes.client import V1ObjectMeta +from kubernetes.client import V1PodSpec +from kubernetes.client import V1Container +from kubernetes.client import V1ResourceRequirements + +from kubeflow.training import TrainingClient +from kubeflow.training import KubeflowOrgV1ReplicaSpec +from kubeflow.training import KubeflowOrgV1JAXJob +from kubeflow.training import KubeflowOrgV1JAXJobSpec +from kubeflow.training import KubeflowOrgV1RunPolicy +from kubeflow.training import KubeflowOrgV1SchedulingPolicy +from kubeflow.training.constants import constants + +import test.e2e.utils as utils +from test.e2e.constants import TEST_GANG_SCHEDULER_NAME_ENV_KEY +from test.e2e.constants import GANG_SCHEDULERS, NONE_GANG_SCHEDULERS + +logging.basicConfig(format="%(message)s") +logging.getLogger("kubeflow.training.api.training_client").setLevel(logging.DEBUG) + +TRAINING_CLIENT = TrainingClient(job_kind=constants.JAXJOB_KIND) +JOB_NAME = "jaxjob-cpu-ci-test" +CONTAINER_NAME = "jax" +GANG_SCHEDULER_NAME = os.getenv(TEST_GANG_SCHEDULER_NAME_ENV_KEY, "") + + +@pytest.mark.skipif( + GANG_SCHEDULER_NAME in NONE_GANG_SCHEDULERS, + reason="For gang-scheduling", +) +def test_sdk_e2e_with_gang_scheduling(job_namespace): + container = generate_container() + + worker = KubeflowOrgV1ReplicaSpec( + replicas=2, + restart_policy="OnFailure", + template=V1PodTemplateSpec( + metadata=V1ObjectMeta( + annotations={constants.ISTIO_SIDECAR_INJECTION: "false"} + ), + spec=V1PodSpec( + scheduler_name=utils.get_pod_spec_scheduler_name(GANG_SCHEDULER_NAME), + containers=[container], + ), + ), + ) + + unschedulable_jaxjob = generate_jaxjob( + job_namespace, worker, KubeflowOrgV1SchedulingPolicy(min_available=10) + ) + schedulable_jaxjob = generate_jaxjob( + job_namespace, worker, KubeflowOrgV1SchedulingPolicy(min_available=2) + ) + + TRAINING_CLIENT.create_job(job=unschedulable_jaxjob, namespace=job_namespace) + logging.info(f"List of created {TRAINING_CLIENT.job_kind}s") + logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) + + try: + utils.verify_unschedulable_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace) + except Exception as e: + utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) + TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) + raise Exception(f"JAXJob E2E fails. Exception: {e}") + + TRAINING_CLIENT.update_job(schedulable_jaxjob, JOB_NAME, job_namespace) + logging.info(f"List of updated {TRAINING_CLIENT.job_kind}s") + logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) + + try: + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=900) + except Exception as e: + utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) + TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) + raise Exception(f"JAXJob E2E fails. Exception: {e}") + + utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) + TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) + + +@pytest.mark.skipif( + GANG_SCHEDULER_NAME in GANG_SCHEDULERS, + reason="For plain scheduling", +) +def test_sdk_e2e(job_namespace): + container = generate_container() + + worker = KubeflowOrgV1ReplicaSpec( + replicas=2, + restart_policy="OnFailure", + template=V1PodTemplateSpec( + metadata=V1ObjectMeta( + annotations={constants.ISTIO_SIDECAR_INJECTION: "false"} + ), + spec=V1PodSpec(containers=[container]), + ), + ) + + jaxjob = generate_jaxjob(job_namespace, worker) + + TRAINING_CLIENT.create_job(job=jaxjob, namespace=job_namespace) + logging.info(f"List of created {TRAINING_CLIENT.job_kind}s") + logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) + + try: + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=900) + except Exception as e: + utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) + TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) + raise Exception(f"JAXJob E2E fails. Exception: {e}") + + utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) + TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) + + +def generate_jaxjob( + job_namespace: str, + worker: KubeflowOrgV1ReplicaSpec, + scheduling_policy: Optional[KubeflowOrgV1SchedulingPolicy] = None, +) -> KubeflowOrgV1JAXJob: + return KubeflowOrgV1JAXJob( + api_version=constants.API_VERSION, + kind=constants.JAXJOB_KIND, + metadata=V1ObjectMeta(name=JOB_NAME, namespace=job_namespace), + spec=KubeflowOrgV1JAXJobSpec( + run_policy=KubeflowOrgV1RunPolicy( + scheduling_policy=scheduling_policy, + clean_pod_policy="None", + ), + jax_replica_specs={"Worker": worker}, + ), + ) + + +def generate_container() -> V1Container: + return V1Container( + name=CONTAINER_NAME, + image="docker.io/sandipanify/jaxgoogle:latest", + command=["python", "train.py"], + resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}), + )