-
Notifications
You must be signed in to change notification settings - Fork 696
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding s3/nutanix object store as dataset provider
- Loading branch information
1 parent
cf226a0
commit 326763e
Showing
4 changed files
with
73 additions
and
0 deletions.
There are no files selected for viewing
11 changes: 11 additions & 0 deletions
11
sdk/python/kubeflow/storage_init_container/abstract_dataset_provider.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class datasetProvider(ABC):
    """Abstract interface for dataset providers.

    A concrete provider must implement both ``load_config`` and
    ``download_dataset`` before it can be instantiated.
    """

    @abstractmethod
    def load_config(self, serialised_args=None):
        """Load the provider's configuration.

        Args:
            serialised_args: Serialized configuration for the provider. The
                parameter defaults to ``None`` only so that existing
                zero-argument calls keep working; concrete providers may
                require it (the S3 provider parses it as a JSON document).
        """

    @abstractmethod
    def download_dataset(self):
        """Download the dataset.

        Presumably called after ``load_config`` so the provider has its
        settings available -- confirm against the calling code.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ torchvision==0.16.1 | |
torchaudio==2.1.1 | ||
einops==0.7.0 | ||
transformers_stream_generator==0.0.4 | ||
boto3==1.33.9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from abstract_dataset_provider import datasetProvider | ||
from dataclasses import dataclass, field | ||
import json | ||
import boto3 | ||
from urllib.parse import urlparse | ||
|
||
|
||
@dataclass
class S3DatasetParams:
    """Settings needed to download one object from an S3-compatible store.

    Attributes:
        access_key: Access key id passed to the S3 client.
        secret_key: Secret access key passed to the S3 client.
        endpoint_url: URL of the S3-compatible endpoint; must contain both a
            scheme and a host.
        bucket_name: Bucket that holds the object.
        file_key: Key of the object to download.
        region_name: Region passed to the S3 client.
        download_dir: Local download location, ``/workspace/datasets`` by
            default.

    Raises:
        ValueError: If ``endpoint_url`` is not a URL with scheme and host.
    """

    access_key: str
    # Excluded from the generated repr so the credential is not echoed when
    # the params object is printed or logged.
    secret_key: str = field(repr=False)
    endpoint_url: str
    bucket_name: str
    file_key: str
    region_name: str
    download_dir: str = field(default="/workspace/datasets")

    def is_valid_url(self, url):
        """Return True when *url* parses with both a scheme and a host.

        Args:
            url: Candidate URL.

        Returns:
            bool: ``False`` for non-string input, for input ``urlparse``
            rejects, and for URLs missing the scheme or the network location.
        """
        if not isinstance(url, (str, bytes)):
            return False
        try:
            parsed_url = urlparse(url)
        except ValueError:
            # urlparse raises ValueError on malformed input such as an
            # unterminated IPv6 literal.
            return False
        return bool(parsed_url.scheme and parsed_url.netloc)

    def __post_init__(self):
        """Validate the fields right after the dataclass is initialised."""
        # The check result must be acted on; otherwise a bad endpoint only
        # surfaces later, when the client is created.
        if not self.is_valid_url(self.endpoint_url):
            raise ValueError(
                f"endpoint_url is not a valid URL: {self.endpoint_url!r}"
            )
|
||
|
||
class S3(datasetProvider):
    """Dataset provider that downloads one object from S3 / Nutanix Object Store."""

    def load_config(self, serialised_args):
        """Build ``self.config`` from a JSON document.

        Args:
            serialised_args: JSON object whose keys are the
                ``S3DatasetParams`` field names.
        """
        self.config = S3DatasetParams(**json.loads(serialised_args))

    def download_dataset(self):
        """Download ``file_key`` from ``bucket_name`` into ``download_dir``.

        The object is written to ``<download_dir>/<basename of file_key>``;
        the directory is created if it does not exist. ``load_config`` must
        have been called first, because this method reads ``self.config``.
        """
        # Imported here so the method stays self-contained regardless of the
        # module's top-level import block.
        import os

        # Create an S3 client for Nutanix Object Store/S3
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=self.config.access_key,
            aws_secret_access_key=self.config.secret_key,
            endpoint_url=self.config.endpoint_url,
            region_name=self.config.region_name,
        )

        # download_file expects the path of the target *file*, so build one
        # inside the download directory instead of passing the directory.
        os.makedirs(self.config.download_dir, exist_ok=True)
        destination = os.path.join(
            self.config.download_dir, os.path.basename(self.config.file_key)
        )

        # Download the file
        s3_client.download_file(
            self.config.bucket_name, self.config.file_key, destination
        )
        print(f"File downloaded to: {destination}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters