diff --git a/sdk/python/kubeflow/training/api/training_client.py b/sdk/python/kubeflow/training/api/training_client.py index a21beecebf..2ab3d79c1d 100644 --- a/sdk/python/kubeflow/training/api/training_client.py +++ b/sdk/python/kubeflow/training/api/training_client.py @@ -191,7 +191,7 @@ def create_job( job = utils.get_pytorchjob_template( name=name, namespace=namespace, - pod_template_spec=pod_template_spec, + worker_pod_template_spec=pod_template_spec, num_worker_replicas=num_worker_replicas, ) else: diff --git a/sdk/python/kubeflow/training/utils/utils.py b/sdk/python/kubeflow/training/utils/utils.py index 347e94f9ba..b3b7ca1d28 100644 --- a/sdk/python/kubeflow/training/utils/utils.py +++ b/sdk/python/kubeflow/training/utils/utils.py @@ -117,6 +117,33 @@ def get_script_for_python_packages( return script_for_python_packages +def get_container_spec( + name: str, + image: str, + args: Optional[List[str]] = None, + resources: Optional[models.V1ResourceRequirements] = None, + volume_mounts: Optional[models.V1VolumeMount] = None, +) -> models.V1Container: + """ + get container spec for given name and image. + """ + if name is None or image is None: + raise ValueError("container name or image cannot be none") + + container_spec = models.V1Container(name=name, image=image) + + if args: + container_spec.args = args + + if resources: + container_spec.resources = resources + + if volume_mounts: + container_spec.volume_mounts = volume_mounts + + return container_spec + + def get_pod_template_spec( job_kind: str, base_image: Optional[str] = None, @@ -124,6 +151,9 @@ def get_pod_template_spec( parameters: Optional[Dict[str, Any]] = None, packages_to_install: Optional[List[str]] = None, pip_index_url: str = constants.DEFAULT_PIP_INDEX_URL, + init_containers_spec: Optional[List[models.V1Container]] = None, + containers_spec: Optional[List[models.V1Container]] = None, + volumes_spec: Optional[List[models.V1Volume]] = None, ): """ Get Pod template spec for the given function and base image. @@ -141,7 +171,7 @@ def get_pod_template_spec( ), spec=models.V1PodSpec( containers=[ - models.V1Container( + get_container_spec( name=constants.JOB_PARAMETERS[job_kind]["container"], image=base_image, ) @@ -149,6 +179,13 @@ def get_pod_template_spec( ), ) + if containers_spec: + pod_template_spec.spec.containers = containers_spec + if init_containers_spec: + pod_template_spec.spec.init_containers = init_containers_spec + if volumes_spec: + pod_template_spec.spec.volumes = volumes_spec + # If Training function is set, convert function to container execution script. if train_func is not None: # Check if function is callable. @@ -261,14 +298,17 @@ def get_tfjob_template( def get_pytorchjob_template( name: str, namespace: str, - pod_template_spec: models.V1PodTemplateSpec, + master_pod_template_spec: models.V1PodTemplateSpec = None, + worker_pod_template_spec: models.V1PodTemplateSpec = None, num_worker_replicas: Optional[int] = None, + num_procs_per_worker: Optional[int] = None, + elastic_policy: Optional[models.KubeflowOrgV1ElasticPolicy] = None, ): # Check if at least one replica is set. # TODO (andreyvelich): Remove this check once we have CEL validation. # Ref: https://github.com/kubeflow/training-operator/issues/1708 - if num_worker_replicas is None: - raise ValueError("At least one Worker replica for PyTorchJob must be set") + if num_worker_replicas is None and master_pod_template_spec is None: + raise ValueError("At least one replica for PyTorchJob must be set") # Create PyTorchJob template. pytorchjob = models.KubeflowOrgV1PyTorchJob( @@ -281,21 +321,25 @@ def get_pytorchjob_template( ), ) - # Add Master and Worker replicas to the PyTorchJob. - pytorchjob.spec.pytorch_replica_specs[ - constants.REPLICA_TYPE_MASTER - ] = models.KubeflowOrgV1ReplicaSpec( - replicas=1, - template=pod_template_spec, - ) + if num_procs_per_worker > 0: + pytorchjob.spec.nproc_per_node = num_procs_per_worker + if elastic_policy: + pytorchjob.spec.elastic_policy = elastic_policy - # If number of Worker replicas is 1, PyTorchJob uses only Master replica. - if num_worker_replicas != 1: + if master_pod_template_spec: + pytorchjob.spec.pytorch_replica_specs[ + constants.REPLICA_TYPE_MASTER + ] = models.KubeflowOrgV1ReplicaSpec( + replicas=1, + template=master_pod_template_spec, + ) + + if num_worker_replicas >= 1: pytorchjob.spec.pytorch_replica_specs[ constants.REPLICA_TYPE_WORKER ] = models.KubeflowOrgV1ReplicaSpec( replicas=num_worker_replicas, - template=pod_template_spec, + template=worker_pod_template_spec, ) return pytorchjob