Skip to content

Commit

Permalink
Convert arg strings to enum values (#644)
Browse files Browse the repository at this point in the history
Fixes #641 as a side effect.
  • Loading branch information
stevemessick authored Oct 2, 2024
1 parent 63244ad commit 3a3c549
Show file tree
Hide file tree
Showing 5 changed files with 391 additions and 269 deletions.
183 changes: 121 additions & 62 deletions kaggle/api/kaggle_api_extended.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
#!/usr/bin/python
#
# Copyright 2024 Kaggle Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/python
#
# Copyright 2024 Kaggle Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/python
#
# Copyright 2019 Kaggle Inc
Expand Down Expand Up @@ -56,10 +56,18 @@
from kaggle.configuration import Configuration
from kagglesdk import KaggleClient, KaggleEnv
from kagglesdk.competitions.types.competition_api_service import *
from kagglesdk.datasets.types.dataset_api_service import ApiListDatasetsRequest, ApiListDatasetFilesRequest, \
ApiGetDatasetStatusRequest, ApiDownloadDatasetRequest, ApiCreateDatasetRequest, ApiCreateDatasetVersionRequestBody, \
ApiCreateDatasetVersionByIdRequest, ApiCreateDatasetVersionRequest, ApiDatasetNewFile
from kagglesdk.datasets.types.dataset_enums import DatasetSelectionGroup, DatasetSortBy
from kagglesdk.datasets.types.dataset_api_service import ApiListDatasetsRequest, \
ApiListDatasetFilesRequest, \
ApiGetDatasetStatusRequest, ApiDownloadDatasetRequest, \
ApiCreateDatasetRequest, ApiCreateDatasetVersionRequestBody, \
ApiCreateDatasetVersionByIdRequest, ApiCreateDatasetVersionRequest, \
ApiDatasetNewFile, ApiUpdateDatasetMetadataRequest, \
ApiGetDatasetMetadataRequest
from kagglesdk.datasets.types.dataset_enums import DatasetSelectionGroup, \
DatasetSortBy, DatasetFileTypeGroup, DatasetLicenseGroup
from kagglesdk.datasets.types.dataset_types import DatasetSettings, \
SettingsLicense, UserRole, DatasetSettingsFile
from kagglesdk.kernels.types.kernels_api_service import ApiListKernelsRequest
from .kaggle_api import KaggleApi
from ..api_client import ApiClient
from ..models.api_blob_type import ApiBlobType
Expand Down Expand Up @@ -313,14 +321,17 @@ class KaggleApi(KaggleApi):
]

# Competitions valid types
valid_competition_groups = ['general', 'entered', 'inClass']
valid_competition_groups = [
'general', 'entered', 'community', 'hosted', 'unlaunched',
'unlaunched_community'
]
valid_competition_categories = [
'all', 'featured', 'research', 'recruitment', 'gettingStarted', 'masters',
'playground'
]
valid_competition_sort_by = [
'grouped', 'prize', 'earliestDeadline', 'latestDeadline', 'numberOfTeams',
'recentlyCreated'
'grouped', 'best', 'prize', 'earliestDeadline', 'latestDeadline',
'numberOfTeams', 'relevance', 'recentlyCreated'
]

# Datasets valid types
Expand Down Expand Up @@ -709,6 +720,10 @@ def camel_to_snake(self, name):
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()

def lookup_enum(self, enum_class, item_name):
  """Resolve a user-facing string to the matching member of *enum_class*.

  Enum members are named ``<CLASS_NAME_SNAKE>_<ITEM_NAME_SNAKE>`` in upper
  case (e.g. ``CompetitionSortBy`` + ``'earliestDeadline'`` ->
  ``COMPETITION_SORT_BY_EARLIEST_DEADLINE``). Raises KeyError when no such
  member exists.
  """
  class_prefix = self.camel_to_snake(enum_class.__name__).upper()
  member_suffix = self.camel_to_snake(item_name).upper()
  return enum_class['%s_%s' % (class_prefix, member_suffix)]

## Competitions

def competitions_list(self,
Expand All @@ -726,20 +741,29 @@ def competitions_list(self,
page: the page to return (default is 1)
search: a search term to use (default is empty string)
sort_by: how to sort the result, see valid_competition_sort_by for options
category: category to filter result to
category: category to filter result to; use 'all' to get closed competitions
group: group to filter result to
"""
if group and group not in self.valid_competition_groups:
raise ValueError('Invalid group specified. Valid options are ' +
str(self.valid_competition_groups))
if group:
if group not in self.valid_competition_groups:
raise ValueError('Invalid group specified. Valid options are ' +
str(self.valid_competition_groups))
if group == 'all':
group = CompetitionListTab.COMPETITION_LIST_TAB_DEFAULT
else:
group = self.lookup_enum(CompetitionListTab, group)

if category and category not in self.valid_competition_categories:
raise ValueError('Invalid category specified. Valid options are ' +
str(self.valid_competition_categories))
if category:
if category not in self.valid_competition_categories:
raise ValueError('Invalid category specified. Valid options are ' +
str(self.valid_competition_categories))
category = self.lookup_enum(HostSegment, category)

if sort_by and sort_by not in self.valid_competition_sort_by:
raise ValueError('Invalid sort_by specified. Valid options are ' +
str(self.valid_competition_sort_by))
if sort_by:
if sort_by not in self.valid_competition_sort_by:
raise ValueError('Invalid sort_by specified. Valid options are ' +
str(self.valid_competition_sort_by))
sort_by = self.lookup_enum(CompetitionSortBy, sort_by)

with self.build_kaggle_client() as kaggle:
request = ApiListCompetitionsRequest()
Expand Down Expand Up @@ -1199,30 +1223,36 @@ def dataset_list(self,
raise ValueError('Invalid sort by specified. Valid options are ' +
str(self.valid_dataset_sort_bys))
else:
sort_by = DatasetSortBy[f"DATASET_SORT_BY_{sort_by.upper()}"]
sort_by = self.lookup_enum(DatasetSortBy, sort_by)

if size:
raise ValueError(
'The --size parameter has been deprecated. ' +
'Please use --max-size and --min-size to filter dataset sizes.')

if file_type and file_type not in self.valid_dataset_file_types:
raise ValueError('Invalid file type specified. Valid options are ' +
str(self.valid_dataset_file_types))
if file_type:
if file_type not in self.valid_dataset_file_types:
raise ValueError('Invalid file type specified. Valid options are ' +
str(self.valid_dataset_file_types))
else:
file_type = self.lookup_enum(DatasetFileTypeGroup, file_type)

if license_name and license_name not in self.valid_dataset_license_names:
raise ValueError('Invalid license specified. Valid options are ' +
str(self.valid_dataset_license_names))
if license_name:
if license_name not in self.valid_dataset_license_names:
raise ValueError('Invalid license specified. Valid options are ' +
str(self.valid_dataset_license_names))
else:
license_name = self.lookup_enum(DatasetLicenseGroup, license_name)

if int(page) <= 0:
raise ValueError('Page number must be >= 1')

if max_size and min_size:
if (int(max_size) < int(min_size)):
if int(max_size) < int(min_size):
raise ValueError('Max Size must be max_size >= min_size')
if (max_size and int(max_size) <= 0):
if max_size and int(max_size) <= 0:
raise ValueError('Max Size must be > 0')
elif (min_size and int(min_size) < 0):
elif min_size and int(min_size) < 0:
raise ValueError('Min Size must be >= 0')

group = DatasetSelectionGroup.DATASET_SELECTION_GROUP_PUBLIC
Expand Down Expand Up @@ -1315,43 +1345,57 @@ def dataset_metadata_update(self, dataset, path):
effective_path) = self.dataset_metadata_prep(dataset, path)
meta_file = self.get_dataset_metadata_file(effective_path)
with open(meta_file, 'r') as f:
metadata = json.load(f)
s = json.load(f)
metadata = json.loads(s)
updateSettingsRequest = DatasetUpdateSettingsRequest(
title=metadata['title'],
subtitle=metadata['subtitle'],
description=metadata['description'],
is_private=metadata['isPrivate'],
licenses=[License(name=l['name']) for l in metadata['licenses']],
keywords=metadata['keywords'],
title=metadata.get('title') or '',
subtitle=metadata.get('subtitle') or '',
description=metadata.get('description') or '',
is_private=metadata.get('isPrivate') or False,
licenses=[License(name=l['name']) for l in metadata['licenses']] if metadata.get('licenses') else [],
keywords=metadata.get('keywords'),
collaborators=[
Collaborator(username=c['username'], role=c['role'])
for c in metadata['collaborators']
],
data=metadata['data'])
Collaborator(username=c['username'], role=c['role'])
for c in metadata['collaborators']
] if metadata.get('collaborators') else [],
data=metadata.get('data'))
result = self.process_response(
self.metadata_post_with_http_info(owner_slug, dataset_slug,
updateSettingsRequest))
if (len(result['errors']) > 0):
[print(e['message']) for e in result['errors']]
exit(1)

def new_license(self, name):
  """Build a SettingsLicense whose name field is set to *name*."""
  license_setting = SettingsLicense()
  license_setting.name = name
  return license_setting

def new_collaborator(self, name, role):
  """Build a UserRole collaborator entry from a username and a role string."""
  user_role = UserRole()
  user_role.username = name
  user_role.role = role
  return user_role

def dataset_metadata(self, dataset, path):
  """Download a dataset's metadata and write it to the metadata JSON file.

  The scraped diff left both the old HTTP-client path
  (metadata_get_with_http_info / result['errorMessage']) and the new
  kaggle-client path merged in this body; this is the coherent post-commit
  version using the kaggle client only.

  Parameters
  ==========
  dataset: the string identifier of the dataset, in [owner]/[name] form
  path: the directory to write the metadata file into (created if missing)

  Returns the path of the metadata file written. Raises Exception when the
  API response carries an error message.
  """
  (owner_slug, dataset_slug,
   effective_path) = self.dataset_metadata_prep(dataset, path)

  if not os.path.exists(effective_path):
    os.makedirs(effective_path)

  with self.build_kaggle_client() as kaggle:
    request = ApiGetDatasetMetadataRequest()
    request.owner_slug = owner_slug
    request.dataset_slug = dataset_slug
    response = kaggle.datasets.dataset_api_client.get_dataset_metadata(
        request)
    if response.error_message:
      raise Exception(response.error_message)

  meta_file = os.path.join(effective_path, self.DATASET_METADATA_FILE)
  with open(meta_file, 'w') as f:
    # Serialize the API's info payload; default=lambda covers any nested
    # SDK objects that are not natively JSON-serializable.
    json.dump(
        response.to_json(response.info),
        f,
        indent=2,
        default=lambda o: o.__dict__)

  return meta_file

Expand Down Expand Up @@ -2109,7 +2153,7 @@ def kernels_list(self,
kernel_type=None,
output_type=None,
sort_by=None):
""" list kernels based on a set of search criteria
""" List kernels based on a set of search criteria.
Parameters
==========
Expand Down Expand Up @@ -2161,6 +2205,21 @@ def kernels_list(self,
if mine:
group = 'profile'

with self.build_kaggle_client() as kaggle:
request = ApiListKernelsRequest()
request.page = page
page_size = page_size
group = group # req
user = user
language = language
kernel_type = kernel_type
output_type = output_type
sort_by = sort_by #req
dataset = dataset
competition = competition
parent_kernel = parent_kernel
search = search

kernels_list_result = self.process_response(
self.kernels_list_with_http_info(
page=page,
Expand Down
1 change: 1 addition & 0 deletions kagglesdk/competitions/types/competition_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class CompetitionListTab(enum.Enum):
COMPETITION_LIST_TAB_HOSTED = 3
COMPETITION_LIST_TAB_UNLAUNCHED = 4
COMPETITION_LIST_TAB_UNLAUNCHED_COMMUNITY = 5
COMPETITION_LIST_TAB_EVERYTHING = 6

class CompetitionSortBy(enum.Enum):
COMPETITION_SORT_BY_GROUPED = 0
Expand Down
2 changes: 1 addition & 1 deletion kagglesdk/kaggle_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _init_session(self):
})

self._try_fill_auth()
self._fill_xsrf_token(iap_token)
# self._fill_xsrf_token(iap_token) # TODO Make this align with original handler.

def _get_iap_token_if_required(self):
if self._env not in (KaggleEnv.STAGING, KaggleEnv.ADMIN):
Expand Down
Loading

0 comments on commit 3a3c549

Please sign in to comment.