Skip to content

Commit

Permalink
Add S3/Nutanix Object Store as a dataset provider
Browse files Browse the repository at this point in the history
  • Loading branch information
deepanker13 committed Dec 13, 2023
1 parent cf226a0 commit 326763e
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod


class datasetProvider(ABC):
    """Abstract interface for dataset providers used by the storage init container.

    Concrete providers (e.g. the S3 provider) implement configuration
    loading and dataset download.
    """

    @abstractmethod
    def load_config(self):
        """Parse provider-specific configuration (e.g. serialized JSON args)."""
        pass

    @abstractmethod
    def download_dataset(self):
        """Download the dataset to local storage using the loaded configuration."""
        pass
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ torchvision==0.16.1
torchaudio==2.1.1
einops==0.7.0
transformers_stream_generator==0.0.4
boto3==1.33.9
49 changes: 49 additions & 0 deletions sdk/python/kubeflow/storage_init_container/s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from abstract_dataset_provider import datasetProvider
from dataclasses import dataclass, field
import json
import boto3
from urllib.parse import urlparse


@dataclass
class S3DatasetParams:
    """Connection and location parameters for an S3-compatible object store.

    Raises:
        ValueError: from ``__post_init__`` if ``endpoint_url`` is malformed.
    """

    access_key: str  # object-store access key id
    secret_key: str  # object-store secret access key
    endpoint_url: str  # e.g. "https://s3.amazonaws.com" or a Nutanix endpoint
    bucket_name: str
    file_key: str  # object key inside the bucket
    region_name: str
    download_dir: str = field(default="/workspace/datasets")

    def is_valid_url(self, url: str) -> bool:
        """Return True if *url* parses with both a scheme and a network location."""
        try:
            parsed_url = urlparse(url)
            return all([parsed_url.scheme, parsed_url.netloc])
        except ValueError:
            return False

    def __post_init__(self):
        """Validate the endpoint URL after dataclass initialization.

        The original code computed validity but ignored the result (and
        printed the parsed URL), silently accepting bad endpoints; fail
        fast instead so misconfiguration surfaces at construction time.
        """
        if not self.is_valid_url(self.endpoint_url):
            raise ValueError(f"Invalid endpoint_url: {self.endpoint_url}")


class S3(datasetProvider):
    """Dataset provider that downloads a single object from S3/Nutanix Object Store."""

    def load_config(self, serialised_args):
        """Deserialize JSON args into an S3DatasetParams stored on ``self.config``.

        Args:
            serialised_args: JSON string whose keys match S3DatasetParams fields.
        """
        self.config = S3DatasetParams(**json.loads(serialised_args))

    def download_dataset(self):
        """Download ``file_key`` from ``bucket_name`` to ``download_dir``."""
        # Create an S3 client for Nutanix Object Store/S3
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=self.config.access_key,
            aws_secret_access_key=self.config.secret_key,
            endpoint_url=self.config.endpoint_url,
            # Fix: original read ``self.config.egion_name`` (typo), which
            # raised AttributeError on every download.
            region_name=self.config.region_name,
        )

        # Download the file
        s3_client.download_file(
            self.config.bucket_name, self.config.file_key, self.config.download_dir
        )
        print(f"File downloaded to: {self.config.download_dir}")
12 changes: 12 additions & 0 deletions sdk/python/kubeflow/storage_init_container/storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
from hugging_face import HuggingFace
from s3 import S3


def model_factory(model_provider, model_provider_args):
Expand All @@ -12,6 +13,16 @@ def model_factory(model_provider, model_provider_args):
return "This is the default case"


def dataset_factory(dataset_provider, dataset_provider_args):
    """Dispatch a dataset download to the provider named by *dataset_provider*.

    For "s3" the serialized args are handed to the S3 provider; any other
    value returns the default-case sentinel string (mirroring model_factory).
    """
    if dataset_provider == "s3":
        provider = S3()
        provider.load_config(dataset_provider_args)
        provider.download_dataset()
    else:
        return "This is the default case"


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="script for downloading model and datasets to PVC."
Expand All @@ -28,3 +39,4 @@ def model_factory(model_provider, model_provider_args):
args = parser.parse_args()

model_factory(args.model_provider, args.model_provider_args)
dataset_factory(args.dataset_provider, args.dataset_provider_args)

0 comments on commit 326763e

Please sign in to comment.