
Commit

code review changes
deepanker13 committed Dec 21, 2023
1 parent 9cb8c9b commit 3c9d65e
Showing 3 changed files with 45 additions and 44 deletions.
70 changes: 40 additions & 30 deletions sdk/python/kubeflow/training/api/training_client.py
@@ -32,6 +32,22 @@
)
from kubeflow.storage_init_container.s3 import S3DatasetParams


from typing import Union
from kubeflow.storage_init_container.hugging_face import (
HuggingFaceModelParams,
HuggingFaceTrainParams,
)
from kubeflow.storage_init_container.s3 import S3DatasetParams

from typing import Union
from kubeflow.storage_init_container.hugging_face import (
HuggingFaceModelParams,
HuggingFaceTrainParams,
HfDatasetParams,
)
from kubeflow.storage_init_container.s3 import S3DatasetParams

logger = logging.getLogger(__name__)

status_logger = utils.StatusLogger(
@@ -97,7 +113,7 @@ def train(
num_procs_per_worker: int = 1,
storage_config: Dict[Literal["size", "storage_class"], str] = None,
model_provider_parameters: HuggingFaceModelParams = None,
dataset_provider_parameters: S3DatasetParams = None,
dataset_provider_parameters: Union[HfDatasetParams, S3DatasetParams] = None,
train_parameters: HuggingFaceTrainParams = None,
resources_per_worker: Dict[Literal["gpu", "cpu", "memory"], any] = None,
):
@@ -116,40 +132,36 @@ def train(
raise ValueError("One of the required parameters is None")

try:
self.core_api.create_namespace(
body=utils.get_namespace_spec(namespace=namespace)
self.core_api.create_namespaced_persistent_volume_claim(
namespace=namespace,
body=utils.get_pvc_spec(
pvc_name=constants.TRAINER_PVC_NAME,
namespace=namespace,
storage_size=storage_config["size"],
storage_class=storage_config["storage_class"],
),
)
except Exception as e:
print(e)

PVC_NAME = "train-job-pvc"
self.core_api.create_namespaced_persistent_volume_claim(
namespace=namespace,
body=utils.get_pvc_spec(
pvc_name=PVC_NAME,
namespace=namespace,
storage_size=storage_config["size"],
storage_class=storage_config["storage_class"],
),
)

if (
resources_per_worker["gpu"] is None and num_procs_per_worker != 0
) or num_procs_per_worker > resources_per_worker["gpu"]:
resources_per_worker["gpu"] is not None
and num_procs_per_worker > resources_per_worker["gpu"]
) or (resources_per_worker["gpu"] is None and num_procs_per_worker != 0):
raise ValueError("Insufficient gpu resources allocated to the container.")

if isinstance(model_provider_parameters, HuggingFaceModelParams):
mp = "hf"

if isinstance(dataset_provider_parameters, S3DatasetParams):
dp = "s3"
elif isinstance(dataset_provider_parameters, HfDatasetParams):
dp = "hf"

# create init container spec
init_container_spec = utils.get_container_spec(
name=constants.JOB_PARAMETERS[constants.PYTORCHJOB_KIND]["init_container"],
image=constants.JOB_PARAMETERS[constants.PYTORCHJOB_KIND][
"init_container_image"
],
name=constants.STORAGE_CONTAINER,
image=constants.STORAGE_CONTAINER_IMAGE,
args=[
"--model_provider",
mp,
@@ -161,19 +173,17 @@ def train(
json.dumps(dataset_provider_parameters.__dict__),
],
volume_mounts=[
models.V1VolumeMount(name="train_job_pv", mount_path="/workspace")
models.V1VolumeMount(name=constants.TRAINER_PV, mount_path="/workspace")
],
)

# create app container spec
container_spec = utils.get_container_spec(
name=constants.JOB_PARAMETERS[constants.PYTORCHJOB_KIND]["container"],
image=constants.JOB_PARAMETERS[constants.PYTORCHJOB_KIND][
"train_container_image"
],
image=constants.TRAINER_TRANSFORMER_IMAGE,
args=["--train_parameters", json.dumps(train_parameters.__dict__)],
volume_mounts=[
models.V1VolumeMount(name="train_job_pv", mount_path="/workspace")
models.V1VolumeMount(name=constants.TRAINER_PV, mount_path="/workspace")
],
resources=models.V1ResourceRequirements(
limits={
@@ -190,9 +200,9 @@ def train(
containers_spec=[container_spec],
volumes_spec=[
models.V1Volume(
name="train_job_pv",
name=constants.TRAINER_PV,
persistent_volume_claim=models.V1PersistentVolumeClaimVolumeSource(
claim_name=PVC_NAME
claim_name=constants.TRAINER_PVC_NAME
),
)
],
@@ -204,9 +214,9 @@ def train(
containers_spec=[init_container_spec, container_spec],
volumes_spec=[
models.V1Volume(
name="train_job_pv",
name=constants.TRAINER_PV,
persistent_volume_claim=models.V1PersistentVolumeClaimVolumeSource(
claim_name=PVC_NAME
claim_name=constants.TRAINER_PVC_NAME
),
)
],
@@ -217,7 +227,7 @@ def train(
namespace=namespace,
master_pod_template_spec=master_pod_template_spec,
worker_pod_template_spec=worker_pod_template_spec,
num_worker_replicas=num_workers,
num_worker_replicas=num_workers - 1,
num_procs_per_worker=num_procs_per_worker,
elastic_policy=models.KubeflowOrgV1ElasticPolicy(rdzv_backend="c10d"),
)
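
Note (not part of the commit): the sketch below restates the two behavioral changes in train() as standalone functions, namely the reworked GPU validation and the dataset-provider dispatch that now accepts either S3 or Hugging Face parameters. The helper names check_gpu_allocation and resolve_dataset_provider are illustrative only and do not exist in the SDK; their bodies mirror the hunks above.

# Illustrative sketch, assuming the import paths shown in this diff.
from typing import Optional, Union

from kubeflow.storage_init_container.hugging_face import HfDatasetParams
from kubeflow.storage_init_container.s3 import S3DatasetParams


def check_gpu_allocation(gpu: Optional[int], num_procs_per_worker: int) -> None:
    # Fail when more processes are requested than GPUs allocated, or when
    # processes are requested but no GPU resources were allocated at all.
    if (gpu is not None and num_procs_per_worker > gpu) or (
        gpu is None and num_procs_per_worker != 0
    ):
        raise ValueError("Insufficient gpu resources allocated to the container.")


def resolve_dataset_provider(params: Union[HfDatasetParams, S3DatasetParams]) -> str:
    # The storage init container is told which provider to use via a short
    # tag; the client maps the parameter class to "s3" or "hf".
    if isinstance(params, S3DatasetParams):
        return "s3"
    if isinstance(params, HfDatasetParams):
        return "hf"
    raise ValueError("Unsupported dataset provider parameters.")


check_gpu_allocation(gpu=2, num_procs_per_worker=2)      # passes
# check_gpu_allocation(gpu=None, num_procs_per_worker=1) # raises ValueError
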
11 changes: 5 additions & 6 deletions sdk/python/kubeflow/training/constants/constants.py
@@ -80,9 +80,11 @@
PYTORCHJOB_CONTAINER = "pytorch"
PYTORCHJOB_REPLICA_TYPES = (REPLICA_TYPE_MASTER.lower(), REPLICA_TYPE_WORKER.lower())
PYTORCHJOB_BASE_IMAGE = "docker.io/pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime"
STORAGE_CONTAINER = "pytorch-storage"
STORAGE_CONTAINER_IMAGE = "docker image path"
TRAINER_TRANSFORMER_IMAGE = "docker image path"
STORAGE_CONTAINER = "storage"
STORAGE_CONTAINER_IMAGE = "quay.io/deepanker_gupta/storage:v1"
TRAINER_TRANSFORMER_IMAGE = "quay.io/deepanker_gupta/trainer:v1"
TRAINER_PVC_NAME = "storage-initializer"
TRAINER_PV = "storage-pv"

# MXJob constants
MXJOB_KIND = "MXJob"
@@ -130,9 +132,6 @@
"plural": PYTORCHJOB_PLURAL,
"container": PYTORCHJOB_CONTAINER,
"base_image": PYTORCHJOB_BASE_IMAGE,
"init_container": STORAGE_CONTAINER,
"init_container_image": STORAGE_CONTAINER_IMAGE,
"train_container_image": TRAINER_TRANSFORMER_IMAGE,
},
MXJOB_KIND: {
"model": models.KubeflowOrgV1MXJob,
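
Note (not part of the commit): to show what the renamed constants feed into, here is a small sketch that builds the shared /workspace volume objects the way the training_client.py hunks above do, using the kubernetes client directly. It illustrates the wiring only and is not code shipped in the SDK.

# Illustrative sketch, assuming the constants module path from this diff.
from kubernetes import client

from kubeflow.training.constants import constants

# The volume name and claim name are now constants ("storage-pv" and
# "storage-initializer") instead of hard-coded strings in the client.
volume = client.V1Volume(
    name=constants.TRAINER_PV,
    persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
        claim_name=constants.TRAINER_PVC_NAME
    ),
)
volume_mount = client.V1VolumeMount(
    name=constants.TRAINER_PV, mount_path="/workspace"
)
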
8 changes: 0 additions & 8 deletions sdk/python/kubeflow/training/utils/utils.py
@@ -365,11 +365,3 @@ def get_pvc_spec(
pvc_spec.spec.storage_class_name = storage_class

return pvc_spec


def get_namespace_spec(namespace):
namespace_spec = models.V1Namespace(
api_version="v1", kind="Namespace", metadata=models.V1ObjectMeta(name=namespace)
)

return namespace_spec
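
Note (not part of the commit): with get_namespace_spec() removed, the client now assumes the target namespace already exists and only provisions the PVC mounted by the storage init container. The call below mirrors the new call site in training_client.py; the namespace, size, and storage class values are illustrative placeholders, not values taken from this commit.

# Illustrative sketch, assuming the module paths shown in this diff.
from kubeflow.training.constants import constants
from kubeflow.training.utils import utils

pvc_spec = utils.get_pvc_spec(
    pvc_name=constants.TRAINER_PVC_NAME,
    namespace="default",        # illustrative namespace, assumed to exist
    storage_size="10Gi",        # illustrative storage_config["size"]
    storage_class="standard",   # illustrative storage_config["storage_class"]
)
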
