Skip to content

Commit

Permalink
adding types.py
Browse files Browse the repository at this point in the history
  • Loading branch information
deepanker13 committed Jan 9, 2024
1 parent c27159b commit cc916f1
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 36 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ jobs:
src: sdk/

- name: Install dependencies
run: pip install pytest python-dateutil urllib3 kubernetes ./sdk/python
run: |
pip install pytest python-dateutil urllib3 kubernetes
pip install ./sdk/python/kubeflow/storage_init_container/requirements.txt
- name: Run unit test for training sdk
run: pytest ./sdk/python/kubeflow/training/api/training_client_test.py
22 changes: 1 addition & 21 deletions sdk/python/kubeflow/storage_init_container/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,6 @@
from kubeflow.storage_init_container.types import *
from kubeflow.storage_init_container.abstract_model_provider import modelProvider
from kubeflow.storage_init_container.abstract_dataset_provider import datasetProvider
from dataclasses import dataclass, field
from urllib.parse import urlparse
import json, os
from datasets import load_dataset
from peft import LoraConfig
import transformers
from transformers import TrainingArguments
import enum
import huggingface_hub


class TRANSFORMER_TYPES(str, enum.Enum):
"""Types of Transformers."""

AutoModelForSequenceClassification = "AutoModelForSequenceClassification"
AutoModelForTokenClassification = "AutoModelForTokenClassification"
AutoModelForQuestionAnswering = "AutoModelForQuestionAnswering"
AutoModelForCausalLM = "AutoModelForCausalLM"
AutoModelForMaskedLM = "AutoModelForMaskedLM"
AutoModelForImageClassification = "AutoModelForImageClassification"


INIT_CONTAINER_MOUNT_PATH = "/workspace"

Expand Down
6 changes: 5 additions & 1 deletion sdk/python/kubeflow/storage_init_container/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,8 @@ transformers_stream_generator==0.0.4
boto3==1.33.9
transformers>=4.20.0
peft>=0.3.0
huggingface_hub==0.16.4
huggingface_hub==0.16.4
datasets>=2.13.2
torch>=1.13.1
torchvision>=0.9.1
torchaudio>=0.8.1
6 changes: 1 addition & 5 deletions sdk/python/kubeflow/storage_init_container/s3.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from .types import *
from kubeflow.storage_init_container.abstract_dataset_provider import datasetProvider
from dataclasses import dataclass, field
import json, os
import boto3
from urllib.parse import urlparse
import os


@dataclass
Expand Down
19 changes: 19 additions & 0 deletions sdk/python/kubeflow/storage_init_container/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dataclasses import dataclass, field
from urllib.parse import urlparse
import json, os
from datasets import load_dataset
from peft import LoraConfig
import transformers
from transformers import TrainingArguments
import enum
import huggingface_hub
from typing import Union

TRANSFORMER_TYPES = Union[
transformers.AutoModelForSequenceClassification,
transformers.AutoModelForTokenClassification,
transformers.AutoModelForQuestionAnswering,
transformers.AutoModelForCausalLM,
transformers.AutoModelForMaskedLM,
transformers.AutoModelForImageClassification,
]
2 changes: 1 addition & 1 deletion sdk/python/kubeflow/training/api/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def train(
"--model_uri",
model_provider_parameters.model_uri,
"--transformer_type",
model_provider_parameters.transformer_type,
model_provider_parameters.transformer_type.__class__.__name__,
"--model_dir",
model_provider_parameters.download_dir,
"--dataset_dir",
Expand Down
7 changes: 0 additions & 7 deletions sdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@
"urllib3>=1.15.1",
"kubernetes>=23.6.0",
"retrying>=1.3.3",
"boto3>=1.33.9",
"transformers>=4.20.0",
"einops>=0.6.1",
"transformers_stream_generator>=0.0.4",
"peft>=0.3.0",
"datasets>=2.13.2",
"huggingface_hub>=0.16.4",
]

setuptools.setup(
Expand Down

0 comments on commit cc916f1

Please sign in to comment.