From ddb081de31ca050af51cf5b532a7a3c7160f12e4 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 31 Mar 2022 11:37:04 -0700
Subject: [PATCH 1/2] add Beaker.create_dataset() method

---
 CHANGELOG.md         |  4 +++
 beaker/client.py     | 82 +++++++++++++++++++++++++++++++++++++++-----
 beaker/exceptions.py |  4 +++
 mypy.ini             |  1 +
 4 files changed, 83 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 73bc5c7..df5a671 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Added
+
+- Added `Beaker.create_dataset()` method.
+
 ## [v0.2.6](https://github.com/allenai/beaker-py/releases/tag/v0.2.6) - 2022-01-19
 
 ### Added
diff --git a/beaker/client.py b/beaker/client.py
index 87a49fa..18634f0 100644
--- a/beaker/client.py
+++ b/beaker/client.py
@@ -1,8 +1,10 @@
 import json
+import os
 import urllib.parse
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional
+from pathlib import Path
+from typing import Any, Dict, Generator, List, Optional, Union
 
 import docker
 import requests
@@ -12,10 +14,14 @@
 
 from .config import Config
 from .exceptions import *
+from .version import VERSION
 
 __all__ = ["Beaker"]
 
 
+PathOrStr = Union[os.PathLike, Path]
+
+
 class Beaker:
     """
     A client for interacting with `Beaker `_.
@@ -67,20 +73,26 @@ def request(
         resource: str,
         method: str = "GET",
         query: Optional[Dict[str, Any]] = None,
-        data: Optional[Dict[str, Any]] = None,
+        data: Optional[Any] = None,
         exceptions_for_status: Optional[Dict[int, BeakerError]] = None,
+        headers: Optional[Dict[str, str]] = None,
+        token: Optional[str] = None,
+        base_url: Optional[str] = None,
     ) -> requests.Response:
         with self._session_with_backoff() as session:
-            url = f"{self.base_url}/{resource}"
+            url = f"{base_url or self.base_url}/{resource}"
             if query is not None:
                 url = url + "?" + urllib.parse.urlencode(query)
+            default_headers = {
+                "Authorization": f"Bearer {token or self.config.user_token}",
+                "Content-Type": "application/json",
+            }
+            if headers is not None:
+                default_headers.update(headers)
             response = getattr(session, method.lower())(
                 url,
-                headers={
-                    "Authorization": f"Bearer {self.config.user_token}",
-                    "Content-Type": "application/json",
-                },
-                data=None if data is None else json.dumps(data),
+                headers=default_headers,
+                data=json.dumps(data) if isinstance(data, dict) else data,
             )
             if exceptions_for_status is not None and response.status_code in exceptions_for_status:
                 raise exceptions_for_status[response.status_code]
@@ -181,6 +193,60 @@ def get_dataset(self, dataset_id: str) -> Dict[str, Any]:
             f"datasets/{dataset_id}", exceptions_for_status={404: DatasetNotFound(dataset_id)}
         ).json()
 
+    def create_dataset(
+        self,
+        name: str,
+        source: PathOrStr,
+        target: Optional[str] = None,
+        workspace: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """
+        Create a dataset with the source file(s).
+        """
+        workspace_name = workspace or self.config.default_workspace
+        if workspace_name is None:
+            raise ValueError("'workspace' argument required")
+
+        # Ensure workspace exists.
+        self.ensure_workspace(workspace_name)
+
+        # Ensure source exists.
+        source: Path = Path(source)
+        if not source.exists():
+            raise FileNotFoundError(source)
+
+        if not source.is_file():
+            raise NotImplementedError("'create_dataset()' only works for single files so far")
+
+        # Create the dataset.
+        dataset_info = self.request(
+            "datasets",
+            method="POST",
+            query={"name": name},
+            data={"workspace": workspace_name, "fileheap": True},
+            exceptions_for_status={409: DatasetConflict(name)},
+        ).json()
+
+        # Upload the file.
+        with source.open("rb") as source_file:
+            self.request(
+                f"datasets/{dataset_info['storage']['id']}/files/{target or source.name}",
+                method="PUT",
+                data=source_file,
+                token=dataset_info["storage"]["token"],
+                base_url=dataset_info["storage"]["address"],
+                headers={
+                    "User-Agent": f"beaker-py v{VERSION}",
+                },
+            )
+
+        # Commit the dataset.
+        return self.request(
+            f"datasets/{dataset_info['id']}",
+            method="PATCH",
+            data={"commit": True},
+        ).json()
+
     def get_logs(self, job_id: str) -> Generator[bytes, None, None]:
         """
         Download the logs for a job.
diff --git a/beaker/exceptions.py b/beaker/exceptions.py
index 418abad..a3315a7 100644
--- a/beaker/exceptions.py
+++ b/beaker/exceptions.py
@@ -34,6 +34,10 @@ class ExperimentConflict(BeakerError):
     pass
 
 
+class DatasetConflict(BeakerError):
+    pass
+
+
 class DatasetNotFound(BeakerError):
     pass
 
diff --git a/mypy.ini b/mypy.ini
index 2bc5ea6..c65ce53 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,6 +1,7 @@
 [mypy]
 ignore_missing_imports = true
 no_site_packages = true
+allow_redefinition = true
 
 [mypy-tests.*]
 strict_optional = false

From 5456651fa9481f206f16b47573aefd56426e36aa Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 31 Mar 2022 12:27:45 -0700
Subject: [PATCH 2/2] add force option

---
 beaker/client.py | 35 +++++++++++++++++++++++++++--------
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/beaker/client.py b/beaker/client.py
index 18634f0..f2154f8 100644
--- a/beaker/client.py
+++ b/beaker/client.py
@@ -190,15 +190,24 @@ def get_dataset(self, dataset_id: str) -> Dict[str, Any]:
         """
         return self.request(
-            f"datasets/{dataset_id}", exceptions_for_status={404: DatasetNotFound(dataset_id)}
+            f"datasets/{urllib.parse.quote(dataset_id, safe='')}",
+            exceptions_for_status={404: DatasetNotFound(dataset_id)},
         ).json()
 
+    def delete_dataset(self, dataset_id: str):
+        self.request(
+            f"datasets/{urllib.parse.quote(dataset_id, safe='')}",
+            method="DELETE",
+            exceptions_for_status={404: DatasetNotFound(dataset_id)},
+        )
+
     def create_dataset(
         self,
         name: str,
         source: PathOrStr,
         target: Optional[str] = None,
         workspace: Optional[str] = None,
+        force: bool = False,
     ) -> Dict[str, Any]:
         """
         Create a dataset with the source file(s).
         """
@@ -219,13 +228,23 @@ def create_dataset(
             raise NotImplementedError("'create_dataset()' only works for single files so far")
 
         # Create the dataset.
-        dataset_info = self.request(
-            "datasets",
-            method="POST",
-            query={"name": name},
-            data={"workspace": workspace_name, "fileheap": True},
-            exceptions_for_status={409: DatasetConflict(name)},
-        ).json()
+        def make_dataset() -> Dict[str, Any]:
+            return self.request(
+                "datasets",
+                method="POST",
+                query={"name": name},
+                data={"workspace": workspace_name, "fileheap": True},
+                exceptions_for_status={409: DatasetConflict(name)},
+            ).json()
+
+        try:
+            dataset_info = make_dataset()
+        except DatasetConflict:
+            if force:
+                self.delete_dataset(f"{self.user}/{name}")
+                dataset_info = make_dataset()
+            else:
+                raise
 
         # Upload the file.
         with source.open("rb") as source_file:
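
For reference, a minimal usage sketch of the method these two patches add, once applied. This is not part of the patches: the client construction via `Beaker.from_env()`, the workspace name, and the dataset/file names below are assumptions for illustration.

    from beaker import Beaker

    # Assumes a configured Beaker user token (e.g. via the environment);
    # `Beaker.from_env()` is assumed to be the usual way to build a client.
    beaker = Beaker.from_env()

    # Upload a single local file as a new dataset. With force=True, a
    # name collision deletes the existing dataset and retries instead of
    # raising DatasetConflict. Name, path, and workspace are hypothetical.
    dataset = beaker.create_dataset(
        "my-results",                     # hypothetical dataset name
        "output/metrics.json",            # hypothetical local file to upload
        target="metrics.json",            # file name within the dataset
        workspace="my-org/my-workspace",  # defaults to config.default_workspace
        force=True,
    )
    print(dataset)

Omitting `force` (or passing `force=False`) keeps the original behavior from the first patch: a 409 from the dataset creation request surfaces as `DatasetConflict`.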