diff --git a/devbin/generate_gcp_ar_cleanup_policy.py b/devbin/generate_gcp_ar_cleanup_policy.py
index f8dcc4f1591..b027e9fb849 100644
--- a/devbin/generate_gcp_ar_cleanup_policy.py
+++ b/devbin/generate_gcp_ar_cleanup_policy.py
@@ -12,15 +12,17 @@ def to_dict(self):
 
 
 class DeletePolicy(CleanupPolicy):
-    def __init__(self,
-                 name: str,
-                 tag_state: str,
-                 *,
-                 tag_prefixes: Optional[List[str]] = None,
-                 version_name_prefixes: Optional[List[str]] = None,
-                 package_name_prefixes: Optional[List[str]] = None,
-                 older_than: Optional[str] = None,
-                 newer_than: Optional[str] = None):
+    def __init__(
+        self,
+        name: str,
+        tag_state: str,
+        *,
+        tag_prefixes: Optional[List[str]] = None,
+        version_name_prefixes: Optional[List[str]] = None,
+        package_name_prefixes: Optional[List[str]] = None,
+        older_than: Optional[str] = None,
+        newer_than: Optional[str] = None
+    ):
         self.name = name
         self.tag_state = tag_state
         self.tag_prefixes = tag_prefixes
@@ -46,15 +48,17 @@ def to_dict(self):
 
 
 class ConditionalKeepPolicy(CleanupPolicy):
-    def __init__(self,
-                 name: str,
-                 tag_state: str,
-                 *,
-                 tag_prefixes: Optional[List[str]] = None,
-                 version_name_prefixes: Optional[List[str]] = None,
-                 package_name_prefixes: Optional[List[str]] = None,
-                 older_than: Optional[str] = None,
-                 newer_than: Optional[str] = None):
+    def __init__(
+        self,
+        name: str,
+        tag_state: str,
+        *,
+        tag_prefixes: Optional[List[str]] = None,
+        version_name_prefixes: Optional[List[str]] = None,
+        package_name_prefixes: Optional[List[str]] = None,
+        older_than: Optional[str] = None,
+        newer_than: Optional[str] = None
+    ):
         self.name = name
         self.tag_state = tag_state
         self.tag_prefixes = tag_prefixes
@@ -80,10 +84,7 @@ def to_dict(self):
 
 
 class MostRecentVersionKeepPolicy(CleanupPolicy):
-    def __init__(self,
-                 name: str,
-                 package_name_prefixes: List[str],
-                 keep_count: int):
+    def __init__(self, name: str, package_name_prefixes: List[str], keep_count: int):
         self.name = name
         self.package_name_prefixes = package_name_prefixes
         self.keep_count = keep_count
@@ -92,10 +93,7 @@ def to_dict(self):
         data = {
             'name': self.name,
             'action': {'type': 'Keep'},
-            'mostRecentVersions': {
-                'packageNamePrefixes': self.package_name_prefixes,
-                'keepCount': self.keep_count
-            }
+            'mostRecentVersions': {'packageNamePrefixes': self.package_name_prefixes, 'keepCount': self.keep_count},
         }
         return data
 
diff --git a/hail/python/hailtop/batch/job.py b/hail/python/hailtop/batch/job.py
index 9085ffe3fb8..a38be18c58b 100644
--- a/hail/python/hailtop/batch/job.py
+++ b/hail/python/hailtop/batch/job.py
@@ -5,7 +5,7 @@
 import textwrap
 import warnings
 from shlex import quote as shq
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast, Literal
 from typing_extensions import Self
 
 import hailtop.batch_client.client as bc
@@ -878,6 +878,19 @@ async def _compile(self, local_tmpdir, remote_tmpdir, *, dry_run=False):
         return True
 
 
+UnpreparedArg = Union['_resource.ResourceType', List['UnpreparedArg'], Tuple['UnpreparedArg', ...], Dict[str, 'UnpreparedArg'], Any]
+
+PreparedArg = Union[
+    Tuple[Literal['py_path'], str],
+    Tuple[Literal['path'], str],
+    Tuple[Literal['dict_path'], Dict[str, str]],
+    Tuple[Literal['list'], List['PreparedArg']],
+    Tuple[Literal['dict'], Dict[str, 'PreparedArg']],
+    Tuple[Literal['tuple'], Tuple['PreparedArg', ...]],
+    Tuple[Literal['value'], Any]
+]
+
+
 class PythonJob(Job):
     """
     Object representing a single Python job to execute.
@@ -924,7 +937,7 @@ def __init__(self,
         super().__init__(batch, token, name=name, attributes=attributes, shell=None)
         self._resources: Dict[str, _resource.Resource] = {}
         self._resources_inverse: Dict[_resource.Resource, str] = {}
-        self._function_calls: List[Tuple[_resource.PythonResult, int, Tuple[Any, ...], Dict[str, Any]]] = []
+        self._function_calls: List[Tuple[_resource.PythonResult, int, Tuple[UnpreparedArg, ...], Dict[str, UnpreparedArg]]] = []
         self.n_results = 0
 
     def _get_python_resource(self, item: str) -> '_resource.PythonResult':
@@ -970,7 +983,7 @@ def image(self, image: str) -> 'PythonJob':
         self._image = image
         return self
 
-    def call(self, unapplied: Callable, *args, **kwargs) -> '_resource.PythonResult':
+    def call(self, unapplied: Callable, *args: UnpreparedArg, **kwargs: UnpreparedArg) -> '_resource.PythonResult':
         """Execute a Python function.
 
         Examples
@@ -1148,7 +1161,7 @@ def handle_args(r):
         return result
 
     async def _compile(self, local_tmpdir, remote_tmpdir, *, dry_run=False):
-        def prepare_argument_for_serialization(arg):
+        def preserialize(arg: UnpreparedArg) -> PreparedArg:
             if isinstance(arg, _resource.PythonResult):
                 return ('py_path', arg._get_path(local_tmpdir))
             if isinstance(arg, _resource.ResourceFile):
@@ -1156,20 +1169,24 @@ def prepare_argument_for_serialization(arg):
             if isinstance(arg, _resource.ResourceGroup):
                 return ('dict_path', {name: resource._get_path(local_tmpdir)
                                       for name, resource in arg._resources.items()})
-            if isinstance(arg, (list, tuple)):
-                return ('value', [prepare_argument_for_serialization(elt) for elt in arg])
+            if isinstance(arg, list):
+                return ('list', [preserialize(elt) for elt in arg])
+            if isinstance(arg, tuple):
+                return ('tuple', tuple((preserialize(elt) for elt in arg)))
             if isinstance(arg, dict):
-                return ('value', {k: prepare_argument_for_serialization(v) for k, v in arg.items()})
+                return ('dict', {k: preserialize(v) for k, v in arg.items()})
             return ('value', arg)
 
         for i, (result, unapplied_id, args, kwargs) in enumerate(self._function_calls):
             func_file = self._batch._python_function_files[unapplied_id]
 
-            prepared_args = prepare_argument_for_serialization(args)[1]
-            prepared_kwargs = prepare_argument_for_serialization(kwargs)[1]
+            preserialized_args = [preserialize(arg) for arg in args]
+            del args
+            preserialized_kwargs = {keyword: preserialize(arg) for keyword, arg in kwargs.items()}
+            del kwargs
 
             args_file = await self._batch._serialize_python_to_input_file(
-                os.path.dirname(result._get_path(remote_tmpdir)), "args", i, (prepared_args, prepared_kwargs), dry_run
+                os.path.dirname(result._get_path(remote_tmpdir)), "args", i, (preserialized_args, preserialized_kwargs), dry_run
             )
 
             json_write, str_write, repr_write = [
@@ -1191,14 +1208,16 @@ def prepare_argument_for_serialization(arg):
 
 def deserialize_argument(arg):
     typ, val = arg
-    if typ == 'value' and isinstance(val, dict):
-        return {{k: deserialize_argument(v) for k, v in val.items()}}
-    if typ == 'value' and isinstance(val, (list, tuple)):
-        return [deserialize_argument(elt) for elt in val]
     if typ == 'py_path':
         return dill.load(open(val, 'rb'))
    if typ in ('path', 'dict_path'):
         return val
+    if typ == 'list':
+        return [deserialize_argument(elt) for elt in val]
+    if typ == 'tuple':
+        return tuple((deserialize_argument(elt) for elt in val))
+    if typ == 'dict':
+        return {{k: deserialize_argument(v) for k, v in val.items()}}
     assert typ == 'value'
     return val
 
@@ -1226,8 +1245,8 @@ def deserialize_argument(arg):
 
             unapplied = self._batch._python_function_defs[unapplied_id]
             self._user_code.append(textwrap.dedent(inspect.getsource(unapplied)))
-            args_str = ', '.join([f'{arg!r}' for _, arg in prepared_args])
-            kwargs_str = ', '.join([f'{k}={v!r}' for k, (_, v) in kwargs.items()])
+            args_str = ', '.join([f'{arg!r}' for _, arg in preserialized_args])
+            kwargs_str = ', '.join([f'{k}={v!r}' for k, (_, v) in preserialized_kwargs.items()])
             separator = ', ' if args_str and kwargs_str else ''
             func_call = f'{unapplied.__name__}({args_str}{separator}{kwargs_str})'
             self._user_code.append(self._interpolate_command(func_call, allow_python_results=True))
diff --git a/hail/python/hailtop/batch/resource.py b/hail/python/hailtop/batch/resource.py
index b8df03ddca0..799d80e8daf 100644
--- a/hail/python/hailtop/batch/resource.py
+++ b/hail/python/hailtop/batch/resource.py
@@ -1,5 +1,5 @@
 import abc
-from typing import Optional, Set, cast
+from typing import Optional, Set, cast, Union
 
 from . import job  # pylint: disable=cyclic-import
 from .exceptions import BatchException
@@ -448,3 +448,6 @@ def __str__(self):
 
     def __repr__(self):
         return self._uid  # pylint: disable=no-member
+
+
+ResourceType = Union[PythonResult, ResourceFile, ResourceGroup]
diff --git a/hail/python/test/hailtop/batch/test_batch.py b/hail/python/test/hailtop/batch/test_batch.py
index e66228ddbd6..efcc88d9d97 100644
--- a/hail/python/test/hailtop/batch/test_batch.py
+++ b/hail/python/test/hailtop/batch/test_batch.py
@@ -10,7 +10,10 @@
 from shlex import quote as shq
 import uuid
 import re
+import orjson
 
+import hailtop.fs as hfs
+import hailtop.batch_client.client as bc
 from hailtop import pip_version
 from hailtop.batch import Batch, ServiceBackend, LocalBackend, ResourceGroup
 from hailtop.batch.resource import JobResourceFile
@@ -1291,6 +1294,47 @@ def test_update_batch_from_batch_id(self):
         res_status = res.status()
         assert res_status['state'] == 'success', str((res_status, res.debug_info()))
 
+    def test_python_job_with_kwarg(self):
+        def foo(*, kwarg):
+            return kwarg
+
+        b = self.batch(default_python_image=PYTHON_DILL_IMAGE)
+        j = b.new_python_job()
+        r = j.call(foo, kwarg='hello world')
+
+        output_path = f'{self.cloud_output_dir}/test_python_job_with_kwarg'
+        b.write_output(r.as_json(), output_path)
+        res = b.run()
+        assert isinstance(res, bc.Batch)
+
+        assert res.status()['state'] == 'success', str((res, res.debug_info()))
+        with hfs.open(output_path) as f:
+            assert orjson.loads(f.read()) == 'hello world'
+
+    def test_tuple_recursive_resource_extraction_in_python_jobs(self):
+        b = self.batch(default_python_image=PYTHON_DILL_IMAGE)
+
+        def write(paths):
+            if not isinstance(paths, tuple):
+                raise ValueError('paths must be a tuple')
+            for i, path in enumerate(paths):
+                with open(path, 'w') as f:
+                    f.write(f'{i}')
+
+        head = b.new_python_job()
+        head.call(write, (head.ofile1, head.ofile2))
+
+        tail = b.new_bash_job()
+        tail.command(f'cat {head.ofile1}')
+        tail.command(f'cat {head.ofile2}')
+
+        res = b.run()
+        assert res
+        assert tail._job_id
+        res_status = res.status()
+        assert res_status['state'] == 'success', str((res_status, res.debug_info()))
+        assert res.get_job_log(tail._job_id)['main'] == '01', str(res.debug_info())
+
     def test_list_recursive_resource_extraction_in_python_jobs(self):
         b = self.batch(default_python_image=PYTHON_DILL_IMAGE)
 
diff --git a/hail/scripts/test_requester_pays_parsing.py b/hail/scripts/test_requester_pays_parsing.py
index d73f09b30dc..d8c5794d377 100644
--- a/hail/scripts/test_requester_pays_parsing.py
+++ b/hail/scripts/test_requester_pays_parsing.py
@@ -5,7 +5,7 @@
 
 from hailtop.aiocloud.aiogoogle import get_gcs_requester_pays_configuration
 from hailtop.aiocloud.aiogoogle.user_config import get_spark_conf_gcs_requester_pays_configuration, spark_conf_path
-from hailtop.config.user_config import ConfigVariable, configuration_of
+from hailtop.config import ConfigVariable, configuration_of
 from hailtop.utils.process import check_exec_output
 
 if 'YOU_MAY_OVERWRITE_MY_SPARK_DEFAULTS_CONF_AND_HAILCTL_SETTINGS' not in os.environ: