diff --git a/batch/batch/cloud/azure/driver/pricing.py b/batch/batch/cloud/azure/driver/pricing.py index b3fa23b9413..8fefc3fdf4e 100644 --- a/batch/batch/cloud/azure/driver/pricing.py +++ b/batch/batch/cloud/azure/driver/pricing.py @@ -146,7 +146,7 @@ async def managed_disk_prices_by_region( prices: List[AzureDiskPrice] = [] seen_disk_names: Dict[str, str] = {} - filter = f'serviceName eq \'Storage\' and armRegionName eq \'{region}\' and endswith(meterName,\'Disks\')' + filter = f'serviceName eq \'Storage\' and armRegionName eq \'{region}\' and ( endswith(meterName,\'Disk\') or endswith(meterName,\'Disks\') )' async for data in pricing_client.list_prices(filter=filter): if data['type'] != 'Consumption' or data['productName'] == 'Premium Page Blob': continue diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py index 9a8ba2451a6..31a96005b75 100644 --- a/batch/batch/worker/worker.py +++ b/batch/batch/worker/worker.py @@ -370,6 +370,10 @@ class ImageNotFound(Exception): pass +class InvalidImageRepository(Exception): + pass + + class Image: def __init__( self, @@ -441,6 +445,8 @@ async def _pull_image(self): raise ImageCannotBePulled from e if 'not found: manifest unknown' in e.message: raise ImageNotFound from e + if 'Invalid repository name' in e.message: + raise InvalidImageRepository from e raise image_config, _ = await check_exec_output('docker', 'inspect', self.image_ref_str) @@ -580,7 +586,7 @@ def user_error(e): # bucket name and your credentials.\n') if b'Bad credentials for bucket' in e.stderr: return True - if isinstance(e, (ImageNotFound, ImageCannotBePulled)): + if isinstance(e, (ImageNotFound, ImageCannotBePulled, InvalidImageRepository)): return True if isinstance(e, (ContainerTimeoutError, ContainerDeletedError)): return True @@ -671,6 +677,8 @@ async def create(self): self.short_error = 'image not found' elif isinstance(e, ImageCannotBePulled): self.short_error = 'image cannot be pulled' + elif isinstance(e, InvalidImageRepository): + self.short_error = 'image repository is invalid' self.state = 'error' self.error = traceback.format_exc() diff --git a/batch/test/failure_injecting_client_session.py b/batch/test/failure_injecting_client_session.py index 3c94e7a070f..ac194d54bc4 100644 --- a/batch/test/failure_injecting_client_session.py +++ b/batch/test/failure_injecting_client_session.py @@ -1,13 +1,13 @@ import aiohttp -from hailtop.httpx import client_session +from hailtop import httpx from hailtop.utils import async_to_blocking -class FailureInjectingClientSession: +class FailureInjectingClientSession(httpx.ClientSession): def __init__(self, should_fail): self.should_fail = should_fail - self.real_session = client_session() + self.real_session = httpx.client_session() def __enter__(self): return self diff --git a/batch/test/test_accounts.py b/batch/test/test_accounts.py index 157c0195021..8a9066d56c7 100644 --- a/batch/test/test_accounts.py +++ b/batch/test/test_accounts.py @@ -75,12 +75,12 @@ async def test_bad_token(): token = session_id_encode_to_str(secrets.token_bytes(32)) bc = await BatchClient.create('test', _token=token) try: - b = bc.create_batch() - j = b.create_job(DOCKER_ROOT_IMAGE, ['false']) - await b.submit() + bb = bc.create_batch() + bb.create_job(DOCKER_ROOT_IMAGE, ['false']) + b = await bb.submit() assert False, str(await b.debug_info()) except aiohttp.ClientResponseError as e: - assert e.status == 401, str((e, await b.debug_info())) + assert e.status == 401 finally: await bc.close() @@ -398,11 +398,11 @@ async def 
test_billing_limit_zero( try: bb = client.create_batch() - batch = await bb.submit() + b = await bb.submit() except aiohttp.ClientResponseError as e: - assert e.status == 403 and 'has exceeded the budget' in e.message, str(await batch.debug_info()) + assert e.status == 403 and 'has exceeded the budget' in e.message else: - assert False, str(await batch.debug_info()) + assert False, str(await b.debug_info()) async def test_billing_limit_tiny( @@ -531,81 +531,81 @@ async def test_batch_cannot_be_accessed_by_users_outside_the_billing_project( assert r['billing_project'] == project user1_client = await make_client(project) - b = user1_client.create_batch() - j = b.create_job(DOCKER_ROOT_IMAGE, command=['sleep', '30']) - b_handle = await b.submit() + bb = user1_client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, command=['sleep', '30']) + b = await bb.submit() user2_client = dev_client - user2_batch = Batch(user2_client, b_handle.id, b_handle.attributes, b_handle.n_jobs, b_handle.token) + user2_batch = Batch(user2_client, b.id, b.attributes, b.n_jobs, b.token) try: try: - await user2_client.get_batch(b_handle.id) + await user2_client.get_batch(b.id) except aiohttp.ClientResponseError as e: - assert e.status == 404, str((e, await b_handle.debug_info())) + assert e.status == 404, str((e, await b.debug_info())) else: - assert False, str(await b_handle.debug_info) + assert False, str(await b.debug_info()) try: await user2_client.get_job(j.batch_id, j.job_id) except aiohttp.ClientResponseError as e: - assert e.status == 404, str((e, await b_handle.debug_info())) + assert e.status == 404, str((e, await b.debug_info())) else: - assert False, str(await b_handle.debug_info()) + assert False, str(await b.debug_info()) try: await user2_client.get_job_log(j.batch_id, j.job_id) except aiohttp.ClientResponseError as e: - assert e.status == 404, str((e, await b_handle.debug_info())) + assert e.status == 404, str((e, await b.debug_info())) else: - assert False, str(await b_handle.debug_info()) + assert False, str(await b.debug_info()) try: await user2_client.get_job_attempts(j.batch_id, j.job_id) except aiohttp.ClientResponseError as e: - assert e.status == 404, str((e, await b_handle.debug_info())) + assert e.status == 404, str((e, await b.debug_info())) else: - assert False, str(await b_handle.debug_info()) + assert False, str(await b.debug_info()) try: await user2_batch.status() except aiohttp.ClientResponseError as e: - assert e.status == 404, str((e, await b_handle.debug_info())) + assert e.status == 404, str((e, await b.debug_info())) else: - assert False, str(await b_handle.debug_info()) + assert False, str(await b.debug_info()) try: await user2_batch.cancel() except aiohttp.ClientResponseError as e: - assert e.status == 404, str((e, await b_handle.debug_info())) + assert e.status == 404, str((e, await b.debug_info())) else: - assert False, str(await b_handle.debug_info()) + assert False, str(await b.debug_info()) try: await user2_batch.delete() except aiohttp.ClientResponseError as e: - assert e.status == 404, str((e, await b_handle.debug_info())) + assert e.status == 404, str((e, await b.debug_info())) else: - assert False, str(await b_handle.debug_info()) + assert False, str(await b.debug_info()) # list batches results for user2 - found, batches = await search_batches(user2_client, b_handle.id, q='') - assert not found, str((b_handle.id, batches, await b_handle.debug_info())) + found, batches = await search_batches(user2_client, b.id, q='') + assert not found, str((b.id, batches, await 
b.debug_info())) - found, batches = await search_batches(user2_client, b_handle.id, q=f'billing_project:{project}') - assert not found, str((b_handle.id, batches, await b_handle.debug_info())) + found, batches = await search_batches(user2_client, b.id, q=f'billing_project:{project}') + assert not found, str((b.id, batches, await b.debug_info())) - found, batches = await search_batches(user2_client, b_handle.id, q='user:test') - assert not found, str((b_handle.id, batches, await b_handle.debug_info())) + found, batches = await search_batches(user2_client, b.id, q='user:test') + assert not found, str((b.id, batches, await b.debug_info())) - found, batches = await search_batches(user2_client, b_handle.id, q=None) - assert not found, str((b_handle.id, batches, await b_handle.debug_info())) + found, batches = await search_batches(user2_client, b.id, q=None) + assert not found, str((b.id, batches, await b.debug_info())) - found, batches = await search_batches(user2_client, b_handle.id, q='user:test-dev') - assert not found, str((b_handle.id, batches, await b_handle.debug_info())) + found, batches = await search_batches(user2_client, b.id, q='user:test-dev') + assert not found, str((b.id, batches, await b.debug_info())) finally: - await b_handle.delete() + await b.delete() async def test_deleted_open_batches_do_not_prevent_billing_project_closure( @@ -613,8 +613,8 @@ async def test_deleted_open_batches_do_not_prevent_billing_project_closure( dev_client: BatchClient, random_billing_project_name: Callable[[], str], ): + project = await dev_client.create_billing_project(random_billing_project_name) try: - project = await dev_client.create_billing_project(random_billing_project_name) await dev_client.add_user('test', project) client = await make_client(project) open_batch = await client.create_batch()._open_batch() diff --git a/batch/test/test_batch.py b/batch/test/test_batch.py index ae299e7a59c..10405c849bb 100644 --- a/batch/test/test_batch.py +++ b/batch/test/test_batch.py @@ -13,7 +13,7 @@ from hailtop.utils import external_requests_client_session, retry_response_returning_functions, sync_sleep_and_backoff from .failure_injecting_client_session import FailureInjectingClientSession -from .utils import fails_in_azure, legacy_batch_status, skip_in_azure, smallest_machine_type +from .utils import legacy_batch_status, skip_in_azure, smallest_machine_type deploy_config = get_deploy_config() @@ -34,9 +34,9 @@ def client(): def test_job(client: BatchClient): - builder = client.create_batch() - j = builder.create_job(DOCKER_ROOT_IMAGE, ['echo', 'test']) - b = builder.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['echo', 'test']) + b = bb.submit() status = j.wait() assert 'attributes' not in status, str((status, b.debug_info())) @@ -48,9 +48,9 @@ def test_job(client: BatchClient): def test_job_running_logs(client: BatchClient): - builder = client.create_batch() - j = builder.create_job(DOCKER_ROOT_IMAGE, ['bash', '-c', 'echo test && sleep 300']) - b = builder.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['bash', '-c', 'echo test && sleep 300']) + b = bb.submit() delay = 1 while True: @@ -67,9 +67,9 @@ def test_job_running_logs(client: BatchClient): def test_exit_code_duration(client: BatchClient): - builder = client.create_batch() - j = builder.create_job(DOCKER_ROOT_IMAGE, ['bash', '-c', 'exit 7']) - b = builder.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['bash', '-c', 'exit 7']) + b = bb.submit() status = 
j.wait() assert status['exit_code'] == 7, str((status, b.debug_info())) assert isinstance(status['duration'], int), str((status, b.debug_info())) @@ -78,16 +78,16 @@ def test_exit_code_duration(client: BatchClient): def test_attributes(client: BatchClient): a = {'name': 'test_attributes', 'foo': 'bar'} - builder = client.create_batch() - j = builder.create_job(DOCKER_ROOT_IMAGE, ['true'], attributes=a) - b = builder.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true'], attributes=a) + b = bb.submit() assert j.attributes() == a, str(b.debug_info()) def test_garbage_image(client: BatchClient): - builder = client.create_batch() - j = builder.create_job('dsafaaadsf', ['echo', 'test']) - b = builder.submit() + bb = client.create_batch() + j = bb.create_job('dsafaaadsf', ['echo', 'test']) + b = bb.submit() status = j.wait() assert j._get_exit_codes(status) == {'main': None}, str((status, b.debug_info())) assert j._get_error(status, 'main') is not None, str((status, b.debug_info())) @@ -95,74 +95,74 @@ def test_garbage_image(client: BatchClient): def test_bad_command(client: BatchClient): - builder = client.create_batch() - j = builder.create_job(DOCKER_ROOT_IMAGE, ['sleep 5']) - b = builder.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['sleep 5']) + b = bb.submit() status = j.wait() assert status['state'] == 'Failed', str((status, b.debug_info())) def test_invalid_resource_requests(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '1', 'memory': '250Gi', 'storage': '1Gi'} - builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) with pytest.raises(aiohttp.client.ClientResponseError, match='resource requests.*unsatisfiable'): - builder.submit() + bb.submit() - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0', 'memory': '1Gi', 'storage': '1Gi'} - builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) with pytest.raises( aiohttp.client.ClientResponseError, match='bad resource request for job.*cpu must be a power of two with a min of 0.25; found.*', ): - builder.submit() + bb.submit() - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.1', 'memory': '1Gi', 'storage': '1Gi'} - builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) with pytest.raises( aiohttp.client.ClientResponseError, match='bad resource request for job.*cpu must be a power of two with a min of 0.25; found.*', ): - builder.submit() + bb.submit() - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': 'foo', 'storage': '1Gi'} - builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) with pytest.raises( aiohttp.client.ClientResponseError, match=".*.resources.memory must match regex:.*.resources.memory must be one of:.*", ): - builder.submit() + bb.submit() - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': '500Mi', 'storage': '10000000Gi'} - builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) with pytest.raises(aiohttp.client.ClientResponseError, match='resource 
requests.*unsatisfiable'): - builder.submit() + bb.submit() - builder = client.create_batch() + bb = client.create_batch() resources = {'storage': '10000000Gi', 'machine_type': smallest_machine_type(CLOUD)} - builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) with pytest.raises(aiohttp.client.ClientResponseError, match='resource requests.*unsatisfiable'): - builder.submit() + bb.submit() def test_out_of_memory(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': '10M', 'storage': '10Gi'} - j = builder.create_job('python:3.6-slim-stretch', ['python', '-c', 'x = "a" * 1000**3'], resources=resources) - b = builder.submit() + j = bb.create_job('python:3.6-slim-stretch', ['python', '-c', 'x = "a" * 1000**3'], resources=resources) + b = bb.submit() status = j.wait() assert j._get_out_of_memory(status, 'main'), str((status, b.debug_info())) def test_out_of_storage(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': '10M', 'storage': '5Gi'} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'fallocate -l 100GiB /foo'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'fallocate -l 100GiB /foo'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Failed', str((status, b.debug_info())) job_log = j.log() @@ -170,12 +170,12 @@ def test_out_of_storage(client: BatchClient): def test_quota_applies_to_volume(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': '10M', 'storage': '5Gi'} - j = builder.create_job( + j = bb.create_job( os.environ['HAIL_VOLUME_IMAGE'], ['/bin/sh', '-c', 'fallocate -l 100GiB /data/foo'], resources=resources ) - b = builder.submit() + b = bb.submit() status = j.wait() assert status['state'] == 'Failed', str((status, b.debug_info())) job_log = j.log() @@ -183,28 +183,28 @@ def test_quota_applies_to_volume(client: BatchClient): def test_quota_shared_by_io_and_rootfs(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': '10M', 'storage': '10Gi'} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'fallocate -l 7GiB /foo'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'fallocate -l 7GiB /foo'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': '10M', 'storage': '10Gi'} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'fallocate -l 7GiB /io/foo'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'fallocate -l 7GiB /io/foo'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': '10M', 'storage': '10Gi'} - j = builder.create_job( + j = bb.create_job( DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'fallocate -l 7GiB /foo; fallocate -l 7GiB /io/foo'], resources=resources, ) - b = builder.submit() + b = bb.submit() status = j.wait() assert status['state'] == 'Failed', 
str((status, b.debug_info())) job_log = j.log() @@ -212,28 +212,28 @@ def test_quota_shared_by_io_and_rootfs(client: BatchClient): def test_nonzero_storage(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': '10M', 'storage': '20Gi'} - j = builder.create_job(UBUNTU_IMAGE, ['/bin/sh', '-c', 'true'], resources=resources) - b = builder.submit() + j = bb.create_job(UBUNTU_IMAGE, ['/bin/sh', '-c', 'true'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) @skip_in_azure() def test_attached_disk(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': '10M', 'storage': '400Gi'} - j = builder.create_job(UBUNTU_IMAGE, ['/bin/sh', '-c', 'df -h; fallocate -l 390GiB /io/foo'], resources=resources) - b = builder.submit() + j = bb.create_job(UBUNTU_IMAGE, ['/bin/sh', '-c', 'df -h; fallocate -l 390GiB /io/foo'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) def test_cwd_from_image_workdir(client: BatchClient): - builder = client.create_batch() - j = builder.create_job(os.environ['HAIL_WORKDIR_IMAGE'], ['/bin/sh', '-c', 'pwd']) - b = builder.submit() + bb = client.create_batch() + j = bb.create_job(os.environ['HAIL_WORKDIR_IMAGE'], ['/bin/sh', '-c', 'pwd']) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) job_log = j.log() @@ -241,8 +241,8 @@ def test_cwd_from_image_workdir(client: BatchClient): def test_unsubmitted_state(client: BatchClient): - builder = client.create_batch() - j = builder.create_job(DOCKER_ROOT_IMAGE, ['echo', 'test']) + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['echo', 'test']) with pytest.raises(ValueError): j.batch_id @@ -257,20 +257,20 @@ def test_unsubmitted_state(client: BatchClient): with pytest.raises(ValueError): j.wait() - builder.submit() + bb.submit() with pytest.raises(ValueError): - builder.create_job(DOCKER_ROOT_IMAGE, ['echo', 'test']) + bb.create_job(DOCKER_ROOT_IMAGE, ['echo', 'test']) def test_list_batches(client: BatchClient): tag = secrets.token_urlsafe(64) - b1 = client.create_batch(attributes={'tag': tag, 'name': 'b1'}) - b1.create_job(DOCKER_ROOT_IMAGE, ['sleep', '3600']) - b1 = b1.submit() + bb1 = client.create_batch(attributes={'tag': tag, 'name': 'b1'}) + bb1.create_job(DOCKER_ROOT_IMAGE, ['sleep', '3600']) + b1 = bb1.submit() - b2 = client.create_batch(attributes={'tag': tag, 'name': 'b2'}) - b2.create_job(DOCKER_ROOT_IMAGE, ['echo', 'test']) - b2 = b2.submit() + bb2 = client.create_batch(attributes={'tag': tag, 'name': 'b2'}) + bb2.create_job(DOCKER_ROOT_IMAGE, ['echo', 'test']) + b2 = bb2.submit() batch_id_test_universe = {b1.id, b2.id} @@ -310,13 +310,13 @@ def assert_batch_ids(expected: Set[int], q=None): def test_list_jobs(client: BatchClient): - b = client.create_batch() - j_success = b.create_job(DOCKER_ROOT_IMAGE, ['true']) - j_failure = b.create_job(DOCKER_ROOT_IMAGE, ['false']) - j_error = b.create_job(DOCKER_ROOT_IMAGE, ['sleep 5'], attributes={'tag': 'bar'}) - j_running = b.create_job(DOCKER_ROOT_IMAGE, ['sleep', '1800'], attributes={'tag': 'foo'}) + bb = client.create_batch() + j_success = bb.create_job(DOCKER_ROOT_IMAGE, ['true']) + j_failure = bb.create_job(DOCKER_ROOT_IMAGE, ['false']) + j_error = bb.create_job(DOCKER_ROOT_IMAGE, ['sleep 5'], attributes={'tag': 'bar'}) + j_running 
= bb.create_job(DOCKER_ROOT_IMAGE, ['sleep', '1800'], attributes={'tag': 'foo'}) - b = b.submit() + b = bb.submit() j_success.wait() j_failure.wait() j_error.wait() @@ -337,26 +337,26 @@ def assert_job_ids(expected, q=None): def test_include_jobs(client: BatchClient): - b1 = client.create_batch() - for i in range(2): - b1.create_job(DOCKER_ROOT_IMAGE, ['true']) - b1 = b1.submit() + bb1 = client.create_batch() + for _ in range(2): + bb1.create_job(DOCKER_ROOT_IMAGE, ['true']) + b1 = bb1.submit() s = b1.status() assert 'jobs' not in s, str((s, b1.debug_info())) def test_fail(client: BatchClient): - b = client.create_batch() - j = b.create_job(DOCKER_ROOT_IMAGE, ['false']) - b = b.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['false']) + b = bb.submit() status = j.wait() assert j._get_exit_code(status, 'main') == 1, str((status, b.debug_info())) def test_unknown_image(client: BatchClient): - b = client.create_batch() - j = b.create_job(f'{DOCKER_PREFIX}/does-not-exist', ['echo', 'test']) - b = b.submit() + bb = client.create_batch() + j = bb.create_job(f'{DOCKER_PREFIX}/does-not-exist', ['echo', 'test']) + b = bb.submit() status = j.wait() try: assert j._get_exit_code(status, 'main') is None @@ -367,10 +367,26 @@ def test_unknown_image(client: BatchClient): raise AssertionError(str((status, b.debug_info())), e) -def test_running_job_log_and_status(client: BatchClient): +@skip_in_azure +def test_invalid_gcr(client: BatchClient): b = client.create_batch() - j = b.create_job(DOCKER_ROOT_IMAGE, ['sleep', '300']) + # GCP projects can't be strictly numeric + j = b.create_job(f'gcr.io/1/does-not-exist', ['echo', 'test']) b = b.submit() + status = j.wait() + try: + assert j._get_exit_code(status, 'main') is None + assert status['status']['container_statuses']['main']['short_error'] == 'image repository is invalid', str( + (status, b.debug_info()) + ) + except Exception as e: + raise AssertionError(str((status, b.debug_info())), e) + + +def test_running_job_log_and_status(client: BatchClient): + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['sleep', '300']) + b = bb.submit() while True: if j.status()['state'] == 'Running' or j.is_complete(): @@ -382,9 +398,9 @@ def test_running_job_log_and_status(client: BatchClient): def test_deleted_job_log(client: BatchClient): - b = client.create_batch() - j = b.create_job(DOCKER_ROOT_IMAGE, ['echo', 'test']) - b = b.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['echo', 'test']) + b = bb.submit() j.wait() b.delete() @@ -398,9 +414,9 @@ def test_deleted_job_log(client: BatchClient): def test_delete_batch(client: BatchClient): - b = client.create_batch() - j = b.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30']) - b = b.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30']) + b = bb.submit() b.delete() # verify doesn't exist @@ -414,9 +430,9 @@ def test_delete_batch(client: BatchClient): def test_cancel_batch(client: BatchClient): - b = client.create_batch() - j = b.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30']) - b = b.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30']) + b = bb.submit() status = j.status() assert status['state'] in ('Ready', 'Running'), str((status, b.debug_info())) @@ -448,9 +464,9 @@ def test_get_nonexistent_job(client: BatchClient): def test_get_job(client: BatchClient): - b = client.create_batch() - j = b.create_job(DOCKER_ROOT_IMAGE, ['true']) - b = b.submit() + bb = 
client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true']) + b = bb.submit() j2 = client.get_job(*j.id) status2 = j2.status() @@ -458,11 +474,11 @@ def test_get_job(client: BatchClient): def test_batch(client: BatchClient): - b = client.create_batch() - j1 = b.create_job(DOCKER_ROOT_IMAGE, ['false']) - j2 = b.create_job(DOCKER_ROOT_IMAGE, ['sleep', '1']) - j3 = b.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30']) - b = b.submit() + bb = client.create_batch() + j1 = bb.create_job(DOCKER_ROOT_IMAGE, ['false']) + j2 = bb.create_job(DOCKER_ROOT_IMAGE, ['sleep', '1']) + bb.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30']) + b = bb.submit() j1.wait() j2.wait() @@ -482,31 +498,31 @@ def test_batch(client: BatchClient): def test_batch_status(client: BatchClient): - b1 = client.create_batch() - b1.create_job(DOCKER_ROOT_IMAGE, ['true']) - b1 = b1.submit() + bb1 = client.create_batch() + bb1.create_job(DOCKER_ROOT_IMAGE, ['true']) + b1 = bb1.submit() b1.wait() b1s = b1.status() assert b1s['complete'] and b1s['state'] == 'success', str((b1s, b1.debug_info())) - b2 = client.create_batch() - b2.create_job(DOCKER_ROOT_IMAGE, ['false']) - b2.create_job(DOCKER_ROOT_IMAGE, ['true']) - b2 = b2.submit() + bb2 = client.create_batch() + bb2.create_job(DOCKER_ROOT_IMAGE, ['false']) + bb2.create_job(DOCKER_ROOT_IMAGE, ['true']) + b2 = bb2.submit() b2.wait() b2s = b2.status() assert b2s['complete'] and b2s['state'] == 'failure', str((b2s, b2.debug_info())) - b3 = client.create_batch() - b3.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30']) - b3 = b3.submit() + bb3 = client.create_batch() + bb3.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30']) + b3 = bb3.submit() b3s = b3.status() assert not b3s['complete'] and b3s['state'] == 'running', str((b3s, b3.debug_info())) b3.cancel() - b4 = client.create_batch() - b4.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30']) - b4 = b4.submit() + bb4 = client.create_batch() + bb4.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30']) + b4 = bb4.submit() b4.cancel() b4.wait() b4s = b4.status() @@ -514,9 +530,9 @@ def test_batch_status(client: BatchClient): def test_log_after_failing_job(client: BatchClient): - b = client.create_batch() - j = b.create_job(DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'echo test; exit 127']) - b = b.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'echo test; exit 127']) + b = bb.submit() status = j.wait() assert 'attributes' not in status, str((status, b.debug_info())) assert status['state'] == 'Failed', str((status, b.debug_info())) @@ -529,9 +545,9 @@ def test_log_after_failing_job(client: BatchClient): def test_long_log_line(client: BatchClient): - b = client.create_batch() - j = b.create_job(DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'for _ in {0..70000}; do echo -n a; done']) - b = b.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['/bin/sh', '-c', 'for _ in {0..70000}; do echo -n a; done']) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) @@ -569,28 +585,28 @@ def test_authorized_users_only(): def test_cloud_image(client: BatchClient): - builder = client.create_batch() - j = builder.create_job(os.environ['HAIL_CURL_IMAGE'], ['echo', 'test']) - b = builder.submit() + bb = client.create_batch() + j = bb.create_job(os.environ['HAIL_CURL_IMAGE'], ['echo', 'test']) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) def test_service_account(client: BatchClient): - b = client.create_batch() - j = 
b.create_job( + bb = client.create_batch() + j = bb.create_job( os.environ['CI_UTILS_IMAGE'], ['/bin/sh', '-c', 'kubectl version'], service_account={'namespace': NAMESPACE, 'name': 'test-batch-sa'}, ) - b = b.submit() + b = bb.submit() status = j.wait() assert j._get_exit_code(status, 'main') == 0, str((status, b.debug_info())) def test_port(client: BatchClient): - builder = client.create_batch() - j = builder.create_job( + bb = client.create_batch() + bb.create_job( DOCKER_ROOT_IMAGE, [ 'bash', @@ -602,15 +618,15 @@ def test_port(client: BatchClient): ], port=5000, ) - b = builder.submit() + b = bb.submit() batch = b.wait() assert batch['state'] == 'success', str((batch, b.debug_info())) def test_timeout(client: BatchClient): - builder = client.create_batch() - j = builder.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30'], timeout=5) - b = builder.submit() + bb = client.create_batch() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['sleep', '30'], timeout=5) + b = bb.submit() status = j.wait() assert status['state'] == 'Error', str((status, b.debug_info())) error_msg = j._get_error(status, 'main') @@ -619,10 +635,10 @@ def test_timeout(client: BatchClient): def test_client_max_size(client: BatchClient): - builder = client.create_batch() - for i in range(4): - builder.create_job(DOCKER_ROOT_IMAGE, ['echo', 'a' * (900 * 1024)]) - builder.submit() + bb = client.create_batch() + for _ in range(4): + bb.create_job(DOCKER_ROOT_IMAGE, ['echo', 'a' * (900 * 1024)]) + bb.submit() def test_restartable_insert(client: BatchClient): @@ -637,12 +653,12 @@ def every_third_time(): with FailureInjectingClientSession(every_third_time) as session: client = BatchClient('test', session=session) - builder = client.create_batch() + bb = client.create_batch() for _ in range(9): - builder.create_job(DOCKER_ROOT_IMAGE, ['echo', 'a']) + bb.create_job(DOCKER_ROOT_IMAGE, ['echo', 'a']) - b = builder.submit(max_bunch_size=1) + b = bb.submit(max_bunch_size=1) b = client.get_batch(b.id) # get a batch untainted by the FailureInjectingClientSession status = b.wait() assert status['state'] == 'success', str((status, b.debug_info())) @@ -652,10 +668,10 @@ def every_third_time(): def test_create_idempotence(client: BatchClient): token = secrets.token_urlsafe(32) - builder1 = client.create_batch(token=token) - builder2 = client.create_batch(token=token) - b1 = builder1._open_batch() - b2 = builder2._open_batch() + bb1 = client.create_batch(token=token) + bb2 = client.create_batch(token=token) + b1 = bb1._open_batch() + b2 = bb2._open_batch() assert b1.id == b2.id @@ -700,11 +716,11 @@ def test_batch_create_validation(): def test_duplicate_parents(client: BatchClient): - batch = client.create_batch() - head = batch.create_job(DOCKER_ROOT_IMAGE, command=['echo', 'head']) - batch.create_job(DOCKER_ROOT_IMAGE, command=['echo', 'tail'], parents=[head, head]) + bb = client.create_batch() + head = bb.create_job(DOCKER_ROOT_IMAGE, command=['echo', 'head']) + bb.create_job(DOCKER_ROOT_IMAGE, command=['echo', 'tail'], parents=[head, head]) try: - batch = batch.submit() + batch = bb.submit() except aiohttp.ClientResponseError as e: assert e.status == 400 else: @@ -713,11 +729,9 @@ def test_duplicate_parents(client: BatchClient): @skip_in_azure() def test_verify_no_access_to_google_metadata_server(client: BatchClient): - builder = client.create_batch() - j = builder.create_job( - os.environ['HAIL_CURL_IMAGE'], ['curl', '-fsSL', 'metadata.google.internal', '--max-time', '10'] - ) - b = builder.submit() + bb = client.create_batch() + j = 
bb.create_job(os.environ['HAIL_CURL_IMAGE'], ['curl', '-fsSL', 'metadata.google.internal', '--max-time', '10']) + b = bb.submit() status = j.wait() assert status['state'] == 'Failed', str((status, b.debug_info())) job_log = j.log() @@ -725,9 +739,9 @@ def test_verify_no_access_to_google_metadata_server(client: BatchClient): def test_verify_no_access_to_metadata_server(client: BatchClient): - builder = client.create_batch() - j = builder.create_job(os.environ['HAIL_CURL_IMAGE'], ['curl', '-fsSL', '169.254.169.254', '--max-time', '10']) - builder.submit() + bb = client.create_batch() + j = bb.create_job(os.environ['HAIL_CURL_IMAGE'], ['curl', '-fsSL', '169.254.169.254', '--max-time', '10']) + b = bb.submit() status = j.wait() assert status['state'] == 'Failed', str((status, b.debug_info())) job_log = j.log() @@ -735,7 +749,7 @@ def test_verify_no_access_to_metadata_server(client: BatchClient): def test_submit_batch_in_job(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() remote_tmpdir = get_user_config().get('batch', 'remote_tmpdir') script = f'''import hailtop.batch as hb backend = hb.ServiceBackend("test", remote_tmpdir="{remote_tmpdir}") @@ -745,12 +759,12 @@ def test_submit_batch_in_job(client: BatchClient): b.run() backend.close() ''' - j = builder.create_job( + j = bb.create_job( os.environ['HAIL_HAIL_BASE_IMAGE'], ['/bin/bash', '-c', f'''python3 -c \'{script}\''''], mount_tokens=True, ) - b = builder.submit() + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) @@ -766,8 +780,8 @@ def test_cant_submit_to_default_with_other_ns_creds(client: BatchClient): backend.close() ''' - builder = client.create_batch() - j = builder.create_job( + bb = client.create_batch() + j = bb.create_job( os.environ['HAIL_HAIL_BASE_IMAGE'], [ '/bin/bash', @@ -779,7 +793,7 @@ def test_cant_submit_to_default_with_other_ns_creds(client: BatchClient): ], mount_tokens=True, ) - b = builder.submit() + b = bb.submit() status = j.wait() if NAMESPACE == 'default': assert status['state'] == 'Success', str((status, b.debug_info())) @@ -787,8 +801,8 @@ def test_cant_submit_to_default_with_other_ns_creds(client: BatchClient): assert status['state'] == 'Failed', str((status, b.debug_info())) assert "Please log in" in j.log()['main'], (str(j.log()['main']), status) - builder = client.create_batch() - j = builder.create_job( + bb = client.create_batch() + j = bb.create_job( os.environ['HAIL_HAIL_BASE_IMAGE'], [ '/bin/bash', @@ -800,7 +814,7 @@ def test_cant_submit_to_default_with_other_ns_creds(client: BatchClient): ], mount_tokens=True, ) - b = builder.submit() + b = bb.submit() status = j.wait() if NAMESPACE == 'default': assert status['state'] == 'Success', str((status, b.debug_info())) @@ -812,7 +826,7 @@ def test_cant_submit_to_default_with_other_ns_creds(client: BatchClient): def test_cannot_contact_other_internal_ips(client: BatchClient): internal_ips = [f'10.128.0.{i}' for i in (10, 11, 12)] - builder = client.create_batch() + bb = client.create_batch() script = f''' if [ "$HAIL_BATCH_WORKER_IP" != "{internal_ips[0]}" ] && ! 
grep -Fq {internal_ips[0]} /etc/hosts; then OTHER_IP={internal_ips[0]} @@ -824,8 +838,8 @@ def test_cannot_contact_other_internal_ips(client: BatchClient): curl -fsSL -m 5 $OTHER_IP ''' - j = builder.create_job(os.environ['HAIL_CURL_IMAGE'], ['/bin/bash', '-c', script], port=5000) - b = builder.submit() + j = bb.create_job(os.environ['HAIL_CURL_IMAGE'], ['/bin/bash', '-c', script], port=5000) + b = bb.submit() status = j.wait() assert status['state'] == 'Failed', str((status, b.debug_info())) job_log = j.log() @@ -836,7 +850,7 @@ def test_cannot_contact_other_internal_ips(client: BatchClient): def test_can_use_google_credentials(client: BatchClient): token = os.environ["HAIL_TOKEN"] remote_tmpdir = get_user_config().get('batch', 'remote_tmpdir') - builder = client.create_batch() + bb = client.create_batch() script = f'''import hail as hl import secrets attempt_token = secrets.token_urlsafe(5) @@ -844,10 +858,10 @@ def test_can_use_google_credentials(client: BatchClient): hl.utils.range_table(10).write(location) hl.read_table(location).show() ''' - j = builder.create_job( + j = bb.create_job( os.environ['HAIL_HAIL_BASE_IMAGE'], ['/bin/bash', '-c', f'python3 -c >out 2>err \'{script}\'; cat out err'] ) - b = builder.submit() + b = bb.submit() status = j.wait() assert status['state'] == 'Success', f'{j.log(), status}' expected_log = '''+-------+ @@ -872,25 +886,25 @@ def test_can_use_google_credentials(client: BatchClient): def test_user_authentication_within_job(client: BatchClient): - batch = client.create_batch() + bb = client.create_batch() cmd = ['bash', '-c', 'hailctl auth user'] - no_token = batch.create_job(os.environ['CI_UTILS_IMAGE'], cmd, mount_tokens=False) - b = batch.submit() + no_token = bb.create_job(os.environ['CI_UTILS_IMAGE'], cmd, mount_tokens=False) + b = bb.submit() no_token_status = no_token.wait() assert no_token_status['state'] == 'Failed', str((no_token_status, b.debug_info())) def test_verify_access_to_public_internet(client: BatchClient): - builder = client.create_batch() - j = builder.create_job(os.environ['HAIL_CURL_IMAGE'], ['curl', '-fsSL', 'example.com']) - b = builder.submit() + bb = client.create_batch() + j = bb.create_job(os.environ['HAIL_CURL_IMAGE'], ['curl', '-fsSL', 'example.com']) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) def test_verify_can_tcp_to_localhost(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() script = ''' set -e nc -l -p 5000 & @@ -899,8 +913,8 @@ def test_verify_can_tcp_to_localhost(client: BatchClient): '''.lstrip( '\n' ) - j = builder.create_job(os.environ['HAIL_NETCAT_UBUNTU_IMAGE'], command=['/bin/bash', '-c', script]) - b = builder.submit() + j = bb.create_job(os.environ['HAIL_NETCAT_UBUNTU_IMAGE'], command=['/bin/bash', '-c', script]) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) job_log = j.log() @@ -908,7 +922,7 @@ def test_verify_can_tcp_to_localhost(client: BatchClient): def test_verify_can_tcp_to_127_0_0_1(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() script = ''' set -e nc -l -p 5000 & @@ -917,8 +931,8 @@ def test_verify_can_tcp_to_127_0_0_1(client: BatchClient): '''.lstrip( '\n' ) - j = builder.create_job(os.environ['HAIL_NETCAT_UBUNTU_IMAGE'], command=['/bin/bash', '-c', script]) - b = builder.submit() + j = bb.create_job(os.environ['HAIL_NETCAT_UBUNTU_IMAGE'], command=['/bin/bash', '-c', script]) + b = bb.submit() status = j.wait() 
assert status['state'] == 'Success', str((status, b.debug_info())) job_log = j.log() @@ -926,7 +940,7 @@ def test_verify_can_tcp_to_127_0_0_1(client: BatchClient): def test_verify_can_tcp_to_self_ip(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() script = ''' set -e nc -l -p 5000 & @@ -935,8 +949,8 @@ def test_verify_can_tcp_to_self_ip(client: BatchClient): '''.lstrip( '\n' ) - j = builder.create_job(os.environ['HAIL_NETCAT_UBUNTU_IMAGE'], command=['/bin/sh', '-c', script]) - b = builder.submit() + j = bb.create_job(os.environ['HAIL_NETCAT_UBUNTU_IMAGE'], command=['/bin/sh', '-c', script]) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) job_log = j.log() @@ -944,12 +958,12 @@ def test_verify_can_tcp_to_self_ip(client: BatchClient): def test_verify_private_network_is_restricted(client: BatchClient): - builder = client.create_batch() - builder.create_job( + bb = client.create_batch() + bb.create_job( os.environ['HAIL_CURL_IMAGE'], command=['curl', 'internal.hail', '--connect-timeout', '60'], network='private' ) try: - builder.submit() + bb.submit() except aiohttp.ClientResponseError as err: assert err.status == 400 assert 'unauthorized network private' in err.message @@ -958,90 +972,90 @@ def test_verify_private_network_is_restricted(client: BatchClient): def test_pool_highmem_instance(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': 'highmem'} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) assert 'highmem' in status['status']['worker'], str((status, b.debug_info())) def test_pool_highmem_instance_cheapest(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '1', 'memory': '5Gi'} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) assert 'highmem' in status['status']['worker'], str((status, b.debug_info())) def test_pool_highcpu_instance(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': 'lowmem'} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) assert 'highcpu' in status['status']['worker'], str((status, b.debug_info())) def test_pool_highcpu_instance_cheapest(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '0.25', 'memory': '50Mi'} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) assert 'highcpu' in status['status']['worker'], str((status, b.debug_info())) def test_pool_standard_instance(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() 
resources = {'cpu': '0.25', 'memory': 'standard'} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) assert 'standard' in status['status']['worker'], str((status, b.debug_info())) def test_pool_standard_instance_cheapest(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'cpu': '1', 'memory': '2.5Gi'} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) assert 'standard' in status['status']['worker'], str((status, b.debug_info())) def test_job_private_instance_preemptible(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'machine_type': smallest_machine_type(CLOUD)} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) assert 'job-private' in status['status']['worker'], str((status, b.debug_info())) def test_job_private_instance_nonpreemptible(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'machine_type': smallest_machine_type(CLOUD), 'preemptible': False} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + b = bb.submit() status = j.wait() assert status['state'] == 'Success', str((status, b.debug_info())) assert 'job-private' in status['status']['worker'], str((status, b.debug_info())) def test_job_private_instance_cancel(client: BatchClient): - builder = client.create_batch() + bb = client.create_batch() resources = {'machine_type': smallest_machine_type(CLOUD)} - j = builder.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) - b = builder.submit() + j = bb.create_job(DOCKER_ROOT_IMAGE, ['true'], resources=resources) + b = bb.submit() delay = 0.1 start = time.time() diff --git a/build.yaml b/build.yaml index a10e292f95f..2ffcf144708 100644 --- a/build.yaml +++ b/build.yaml @@ -2596,7 +2596,7 @@ steps: git config user.name ci git config user.email ci@hail.is git add * && git commit -m "setup repo" - git push + retry git push secrets: - name: hail-ci-0-1-service-account-key namespace: diff --git a/ci/ci/constants.py b/ci/ci/constants.py index 955885db872..3c438a8ae82 100644 --- a/ci/ci/constants.py +++ b/ci/ci/constants.py @@ -27,7 +27,6 @@ def __init__(self, gh_username: str, hail_username: Optional[str] = None, teams: User('konradjk', 'konradk'), User('lfrancioli'), User('lgruen'), - User('mkveerapen'), User('nawatts'), User('patrick-schultz', 'pschultz', [COMPILER_TEAM]), User('pwc2', 'pcumming'), @@ -35,5 +34,5 @@ def __init__(self, gh_username: str, hail_username: Optional[str] = None, teams: User('lgruen', 'lgruensc', []), User('vladsaveliev', 'vsavelye', []), User('illusional', 'mfrankli', []), - User('Aleisha02', 'aleisha'), + User('iris-garden', 'irademac'), ] diff --git a/hail/Makefile b/hail/Makefile index 00b84923387..31f15acb438 100644 --- a/hail/Makefile +++ 
b/hail/Makefile @@ -104,7 +104,7 @@ services-jvm-test: $(SCALA_BUILD_INFO) $(JAR_SOURCES) $(JAR_TEST_SOURCES) # javac args from compileJava in build.gradle $(BUILD_DEBUG_PREFIX)/%.class: src/debug/scala/%.java @mkdir -p $(BUILD_DEBUG_PREFIX) - $(JAVAC) -d $(BUILD_DEBUG_PREFIX) -Xlint:all -Werror -XDenableSunApiLintControl -Xlint:-sunapi $< + $(JAVAC) -d $(BUILD_DEBUG_PREFIX) -Xlint:all -Werror -XDenableSunApiLintControl -XDignore.symbol.file $< src/main/resources/build-info.properties: env/REVISION env/SHORT_REVISION env/BRANCH src/main/resources/build-info.properties: env/SPARK_VERSION env/HAIL_PIP_VERSION diff --git a/hail/build.gradle b/hail/build.gradle index f4107778544..6ad4c6a1712 100644 --- a/hail/build.gradle +++ b/hail/build.gradle @@ -37,7 +37,7 @@ sourceSets { } compileJava { - options.compilerArgs << "-Xlint:all" << "-Werror" << "-XDenableSunApiLintControl" << "-Xlint:-sunapi" + options.compilerArgs << "-Xlint:all" << "-Werror" << "-XDenableSunApiLintControl" << "-XDignore.symbol.file" } tasks.withType(JavaCompile) { options.fork = true // necessary to make -XDenableSunApiLintControl work @@ -61,7 +61,7 @@ compileScala { "-Xlint:all" << "-Werror" << "-XDenableSunApiLintControl" << - "-Xlint:-sunapi" << + "-XDignore.symbol.file" << "-Xlint:-path" // Apparently we try to find some libraries that aren't always installed scalaCompileOptions.additionalParameters = [ diff --git a/hail/python/dev/pinned-requirements.txt b/hail/python/dev/pinned-requirements.txt index a1c13ad0de9..0075dbbb9f0 100644 --- a/hail/python/dev/pinned-requirements.txt +++ b/hail/python/dev/pinned-requirements.txt @@ -27,7 +27,7 @@ backcall==0.2.0 # via ipython beautifulsoup4==4.11.1 # via nbconvert -black==22.3.0 +black==22.8.0 # via -r python/dev/requirements.txt bleach==5.0.0 # via nbconvert @@ -72,7 +72,7 @@ execnet==1.9.0 # via pytest-xdist fastjsonschema==2.15.3 # via nbformat -filelock==3.7.1 +filelock==3.8.0 # via virtualenv flake8==4.0.1 # via -r python/dev/requirements.txt @@ -94,7 +94,7 @@ importlib-metadata==3.10.1 # pre-commit # pytest # virtualenv -importlib-resources==5.7.1 +importlib-resources==5.9.0 # via jsonschema iniconfig==1.1.1 # via pytest @@ -388,7 +388,7 @@ types-decorator==5.1.7 # via -r python/dev/requirements.txt types-deprecated==1.2.8 # via -r python/dev/requirements.txt -types-python-dateutil==2.8.17 +types-python-dateutil==2.8.19 # via -r python/dev/requirements.txt types-pyyaml==6.0.8 # via -r python/dev/requirements.txt @@ -413,7 +413,7 @@ typing-extensions==4.2.0 # jsonschema # mypy # pylint -urllib3==1.26.9 +urllib3==1.26.12 # via requests virtualenv==20.14.1 # via pre-commit diff --git a/hail/python/dev/requirements.txt b/hail/python/dev/requirements.txt index a238ba72596..1325f9a660d 100644 --- a/hail/python/dev/requirements.txt +++ b/hail/python/dev/requirements.txt @@ -4,7 +4,7 @@ flake8==4.0.1 mypy==0.950 pylint==2.13.5 pre-commit==2.18.1 -black==22.3.0 +black==22.8.0 curlylint==0.13.1 click==8.1.2 isort==5.10.1 diff --git a/hail/python/hail/docs/functions/index.rst b/hail/python/hail/docs/functions/index.rst index a9c93513f58..3d9b0fc635b 100644 --- a/hail/python/hail/docs/functions/index.rst +++ b/hail/python/hail/docs/functions/index.rst @@ -98,6 +98,7 @@ These functions are exposed at the top level of the module, e.g. ``hl.case``. 
bit_lshift bit_rshift bit_not + bit_count exp expit is_nan diff --git a/hail/python/hail/docs/functions/numeric.rst b/hail/python/hail/docs/functions/numeric.rst index cad65a776e9..90ffe0edfa3 100644 --- a/hail/python/hail/docs/functions/numeric.rst +++ b/hail/python/hail/docs/functions/numeric.rst @@ -14,6 +14,7 @@ Numeric functions bit_lshift bit_rshift bit_not + bit_count exp expit is_nan @@ -60,6 +61,7 @@ Numeric functions .. autofunction:: bit_lshift .. autofunction:: bit_rshift .. autofunction:: bit_not +.. autofunction:: bit_count .. autofunction:: exp .. autofunction:: expit .. autofunction:: is_nan diff --git a/hail/python/hail/expr/__init__.py b/hail/python/hail/expr/__init__.py index 7bf537ea932..1db92ea119f 100644 --- a/hail/python/hail/expr/__init__.py +++ b/hail/python/hail/expr/__init__.py @@ -34,7 +34,7 @@ parse_float, parse_float32, parse_float64, int, int32, int64, parse_int, parse_int32, parse_int64, bool, get_sequence, reverse_complement, is_valid_contig, is_valid_locus, contig_length, liftover, min_rep, uniroot, format, approx_equal, reversed, bit_and, bit_or, - bit_xor, bit_lshift, bit_rshift, bit_not, binary_search, logit, expit, _values_similar, + bit_xor, bit_lshift, bit_rshift, bit_not, bit_count, binary_search, logit, expit, _values_similar, _showstr, _sort_by, _compare, _locus_windows_per_contig, shuffle, _console_log, dnorm, dchisq) __all__ = ['HailType', @@ -243,6 +243,7 @@ 'bit_lshift', 'bit_rshift', 'bit_not', + 'bit_count', 'binary_search', '_values_similar', '_showstr', diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py index b95057e358c..3903b8d50b1 100644 --- a/hail/python/hail/expr/functions.py +++ b/hail/python/hail/expr/functions.py @@ -6193,6 +6193,28 @@ def bit_not(x): return construct_expr(ir.ApplyUnaryPrimOp('~', x._ir), x.dtype, x._indices, x._aggregations) +@typecheck(x=expr_oneof(expr_int32, expr_int64)) +def bit_count(x): + """Count the number of 1s in the `two's complement `__ binary representation of `x`. + + Examples + -------- + The binary representation of `7` is `111`, so: + + >>> hl.eval(hl.bit_count(7)) + 3 + + Parameters + ---------- + x : :class:`.Int32Expression` or :class:`.Int64Expression` + + Returns + ---------- + :class:`.Int32Expression` + """ + return construct_expr(ir.ApplyUnaryPrimOp('BitCount', x._ir), tint32, x._indices, x._aggregations) + + +@typecheck(array=expr_array(expr_numeric), elem=expr_numeric) def binary_search(array, elem) -> Int32Expression: """Binary search `array` for the insertion point of `elem`.
diff --git a/hail/python/hail/ggplot/ggplot.py b/hail/python/hail/ggplot/ggplot.py index 77c13139439..0ce69e1c94e 100644 --- a/hail/python/hail/ggplot/ggplot.py +++ b/hail/python/hail/ggplot/ggplot.py @@ -101,7 +101,9 @@ def copy(self): return GGPlot(self.ht, self.aes, self.geoms[:], self.labels, self.coord_cartesian, self.scales, self.facet) def verify_scales(self): - for geom_idx, geom in enumerate(self.geoms): + for aes_key in self.aes.keys(): + check_scale_continuity(self.scales[aes_key], self.aes[aes_key].dtype, aes_key) + for geom in self.geoms: aesthetic_dict = geom.aes.properties for aes_key in aesthetic_dict.keys(): check_scale_continuity(self.scales[aes_key], aesthetic_dict[aes_key].dtype, aes_key) @@ -160,9 +162,9 @@ def get_aggregation_result(selected, mapping_per_geom, precomputed): stat = self.geoms[geom_idx].get_stat() geom_label = make_geom_label(geom_idx) if use_faceting: - agg = hl.agg.group_by(selected.facet, stat.make_agg(combined_mapping, precomputed[geom_label])) + agg = hl.agg.group_by(selected.facet, stat.make_agg(combined_mapping, precomputed[geom_label], self.scales)) else: - agg = stat.make_agg(combined_mapping, precomputed[geom_label]) + agg = stat.make_agg(combined_mapping, precomputed[geom_label], self.scales) aggregators[geom_label] = agg labels_to_stats[geom_label] = stat diff --git a/hail/python/hail/ggplot/scale.py b/hail/python/hail/ggplot/scale.py index 19aa81289a6..89a14637bcc 100644 --- a/hail/python/hail/ggplot/scale.py +++ b/hail/python/hail/ggplot/scale.py @@ -2,8 +2,9 @@ from .geoms import FigureAttribute from hail.context import get_reference +from hail import tstr -from .utils import categorical_strings_to_colors, continuous_nums_to_colors +from .utils import categorical_strings_to_colors, continuous_nums_to_colors, is_continuous_type, is_discrete_type import plotly.express as px import plotly @@ -28,6 +29,9 @@ def is_discrete(self): def is_continuous(self): pass + def valid_dtype(self, dtype): + pass + class PositionScale(Scale): def __init__(self, aesthetic_name, name, breaks, labels): @@ -53,6 +57,9 @@ def apply_to_fig(self, parent, fig_so_far): if self.labels is not None: self.update_axis(fig_so_far)(ticktext=self.labels) + def valid_dtype(self, dtype): + return True + class PositionScaleGenomic(PositionScale): def __init__(self, aesthetic_name, reference_genome, name=None): @@ -135,6 +142,9 @@ def is_discrete(self): def is_continuous(self): return True + def valid_dtype(self, dtype): + return is_continuous_type(dtype) + class ScaleDiscrete(Scale): def __init__(self, aesthetic_name): @@ -149,6 +159,9 @@ def is_discrete(self): def is_continuous(self): return False + def valid_dtype(self, dtype): + return is_discrete_type(dtype) + class ScaleColorManual(ScaleDiscrete): @@ -226,9 +239,9 @@ def transform(df): return transform -# Legend names messed up for scale color identity -class ScaleColorDiscreteIdentity(ScaleDiscrete): - pass +class ScaleColorContinuousIdentity(ScaleContinuous): + def valid_dtype(self, dtype): + return dtype == tstr def scale_x_log10(name=None): @@ -439,7 +452,7 @@ def scale_color_identity(): :class:`.FigureAttribute` The scale to be applied. """ - return ScaleColorDiscreteIdentity("color") + return ScaleColorContinuousIdentity("color") def scale_color_manual(*, values): @@ -471,7 +484,7 @@ def scale_fill_discrete(): def scale_fill_continuous(): - """The default discrete fill scale. This linearly interpolates colors between the min and max observed values. + """The default continuous fill scale. 
This linearly interpolates colors between the min and max observed values. Returns ------- @@ -489,7 +502,7 @@ def scale_fill_identity(): :class:`.FigureAttribute` The scale to be applied. """ - return ScaleColorDiscreteIdentity("fill") + return ScaleColorContinuousIdentity("fill") def scale_fill_hue(): diff --git a/hail/python/hail/ggplot/stats.py b/hail/python/hail/ggplot/stats.py index 3ef02a78810..3adb51f71d9 100644 --- a/hail/python/hail/ggplot/stats.py +++ b/hail/python/hail/ggplot/stats.py @@ -5,12 +5,12 @@ import hail as hl from hail.utils.java import warning -from .utils import should_use_for_grouping +from .utils import should_use_scale_for_grouping class Stat: @abc.abstractmethod - def make_agg(self, mapping, precomputed): + def make_agg(self, mapping, precomputed, scales): return @abc.abstractmethod @@ -23,9 +23,9 @@ def get_precomputes(self, mapping): class StatIdentity(Stat): - def make_agg(self, mapping, precomputed): + def make_agg(self, mapping, precomputed, scales): grouping_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() - if should_use_for_grouping(aes_key, mapping[aes_key].dtype)} + if should_use_scale_for_grouping(scales[aes_key])} non_grouping_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() if aes_key not in grouping_variables} return hl.agg.group_by(hl.struct(**grouping_variables), hl.agg.collect(hl.struct(**non_grouping_variables))) @@ -49,13 +49,13 @@ class StatFunction(StatIdentity): def __init__(self, fun): self.fun = fun - def make_agg(self, mapping, precomputed): + def make_agg(self, mapping, precomputed, scales): with_y_value = mapping.annotate(y=self.fun(mapping.x)) - return super().make_agg(with_y_value, precomputed) + return super().make_agg(with_y_value, precomputed, scales) class StatNone(Stat): - def make_agg(self, mapping, precomputed): + def make_agg(self, mapping, precomputed, scales): return hl.agg.take(hl.struct(), 0) def listify(self, agg_result): @@ -63,9 +63,9 @@ def listify(self, agg_result): class StatCount(Stat): - def make_agg(self, mapping, precomputed): + def make_agg(self, mapping, precomputed, scales): grouping_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() - if should_use_for_grouping(aes_key, mapping[aes_key].dtype)} + if should_use_scale_for_grouping(scales[aes_key])} if "weight" in mapping: return hl.agg.group_by(hl.struct(**grouping_variables), hl.agg.counter(mapping["x"], weight=mapping["weight"])) return hl.agg.group_by(hl.struct(**grouping_variables), hl.agg.group_by(mapping["x"], hl.agg.count())) @@ -102,9 +102,9 @@ def get_precomputes(self, mapping): precomputes["max_val"] = hl.agg.max(mapping.x) return hl.struct(**precomputes) - def make_agg(self, mapping, precomputed): + def make_agg(self, mapping, precomputed, scales): grouping_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() - if should_use_for_grouping(aes_key, mapping[aes_key].dtype)} + if should_use_scale_for_grouping(scales[aes_key])} start = self.min_val if self.min_val is not None else precomputed.min_val end = self.max_val if self.max_val is not None else precomputed.max_val @@ -137,9 +137,9 @@ class StatCDF(Stat): def __init__(self, k): self.k = k - def make_agg(self, mapping, precomputed): + def make_agg(self, mapping, precomputed, scales): grouping_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() - if should_use_for_grouping(aes_key, mapping[aes_key].dtype)} + if should_use_scale_for_grouping(scales[aes_key])} return 
hl.agg.group_by(hl.struct(**grouping_variables), hl.agg.approx_cdf(mapping["x"], self.k))   def listify(self, agg_result): diff --git a/hail/python/hail/ggplot/utils.py b/hail/python/hail/ggplot/utils.py index 51135d87ca8..5a1c21842b5 100644 --- a/hail/python/hail/ggplot/utils.py +++ b/hail/python/hail/ggplot/utils.py @@ -3,11 +3,8 @@   def check_scale_continuity(scale, dtype, aes_key): - -    if scale.is_discrete() and not is_discrete_type(dtype): -        raise ValueError(f"Aesthetic {aes_key} has discrete scale but not a discrete type.") -    if scale.is_continuous() and not is_continuous_type(dtype): -        raise ValueError(f"Aesthetic {aes_key} has continuous scale but not a continuous type.") +    if not scale.valid_dtype(dtype): +        raise ValueError(f"Invalid scale for aesthetic {aes_key} of type {dtype}")   def is_genomic_type(dtype): @@ -25,7 +22,7 @@ def is_discrete_type(dtype):  excluded_from_grouping = {"x", "tooltip", "label"}   -def should_use_for_grouping(name, type): -    return (name not in excluded_from_grouping) and is_discrete_type(type) +def should_use_scale_for_grouping(scale): +    return (scale.aesthetic_name not in excluded_from_grouping) and scale.is_discrete() diff --git a/hail/python/hail/ir/__init__.py b/hail/python/hail/ir/__init__.py index a31d1a5bd9c..8cb34ebb4a9 100644 --- a/hail/python/hail/ir/__init__.py +++ b/hail/python/hail/ir/__init__.py @@ -54,7 +54,7 @@     BlockMatrixBinaryReader, BlockMatrixPersistReader from .matrix_writer import MatrixWriter, MatrixNativeWriter, MatrixVCFWriter, \     MatrixGENWriter, MatrixBGENWriter, MatrixPLINKWriter, MatrixNativeMultiWriter, MatrixBlockMatrixWriter -from .table_writer import TableWriter, TableNativeWriter, TableTextWriter +from .table_writer import (TableWriter, TableNativeWriter, TableTextWriter, TableNativeFanoutWriter) from .blockmatrix_writer import BlockMatrixWriter, BlockMatrixNativeWriter, \     BlockMatrixBinaryWriter, BlockMatrixRectanglesWriter, \     BlockMatrixMultiWriter, BlockMatrixBinaryMultiWriter, \ @@ -310,5 +310,6 @@     'AvroTableReader',     'TableWriter',     'TableNativeWriter', -    'TableTextWriter' +    'TableTextWriter', +    'TableNativeFanoutWriter' ] diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 9dc161fffd5..5ba4d00be79 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -537,7 +537,10 @@ def _eq(self, other):      def _compute_type(self, env, agg_env, deep_typecheck):          self.x.compute_type(env, agg_env, deep_typecheck) -        return self.x.typ +        if self.op == 'BitCount': +            return tint32 +        else: +            return self.x.typ   class ApplyComparisonOp(IR): diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py index c9c41379915..4480aeac166 100644 --- a/hail/python/hail/ir/matrix_ir.py +++ b/hail/python/hail/ir/matrix_ir.py @@ -575,7 +575,7 @@ def __init__(self, child):          self.child = child      def _handle_randomness(self, row_uid_field_name, col_uid_field_name): -        child = self.child(row_uid_field_name, col_uid_field_name) +        child = self.child.handle_randomness(row_uid_field_name, col_uid_field_name)          result = MatrixCollectColsByKey(child)          if col_uid_field_name is not None:              col = ir.Ref('sa', result.typ.col_type) diff --git a/hail/python/hail/ir/table_writer.py b/hail/python/hail/ir/table_writer.py index d609cb7ecba..891b0401ff1 100644 --- a/hail/python/hail/ir/table_writer.py +++ b/hail/python/hail/ir/table_writer.py @@ -1,6 +1,6 @@  import abc import json -from ..typecheck import typecheck_method, nullable +from ..typecheck import typecheck_method, nullable, sequenceof  from ..utils.misc import escape_str from .export_type import ExportType @@ -73,3 +73,35 @@ def __eq__(self, other): 
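Taken together, the ggplot changes above route dtype validation through each scale's `valid_dtype` and decide stat grouping from the attached scales rather than from raw dtypes. A minimal usage sketch of the intended effect, assuming the public `hail.ggplot` exports (`ggplot`, `aes`, `geom_point`, `scale_color_discrete`), which are not part of this diff:

```python
import hail as hl
from hail.ggplot import ggplot, aes, geom_point, scale_color_discrete

ht = hl.utils.range_table(10)
ht = ht.annotate(group=hl.str(ht.idx % 2))

# 'group' is a string column, so a discrete color scale passes the new
# valid_dtype check; attaching the same scale to a numeric aesthetic would
# instead raise the ValueError from check_scale_continuity.
p = (ggplot(ht, aes(x=ht.idx, y=ht.idx, color=ht.group))
     + geom_point()
     + scale_color_discrete())
fig = p.to_plotly()
```

Because `should_use_scale_for_grouping` consults the scale, overriding a scale also changes how stats such as `StatCount` form their aggregation groups.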
other.header == self.header and \ other.export_type == self.export_type and \ other.delimiter == self.delimiter + + +class TableNativeFanoutWriter(TableWriter): +    @typecheck_method(path=str, +                      fields=sequenceof(str), +                      overwrite=bool, +                      stage_locally=bool, +                      codec_spec=nullable(str)) +    def __init__(self, path, fields, overwrite, stage_locally, codec_spec): +        super(TableNativeFanoutWriter, self).__init__() +        self.path = path +        self.fields = fields +        self.overwrite = overwrite +        self.stage_locally = stage_locally +        self.codec_spec = codec_spec + +    def render(self): +        writer = {'name': 'TableNativeFanoutWriter', +                  'path': self.path, +                  'fields': self.fields, +                  'overwrite': self.overwrite, +                  'stageLocally': self.stage_locally, +                  'codecSpecJSONStr': self.codec_spec} +        return escape_str(json.dumps(writer)) + +    def __eq__(self, other): +        return isinstance(other, TableNativeFanoutWriter) and \ +            other.path == self.path and \ +            other.fields == self.fields and \ +            other.overwrite == self.overwrite and \ +            other.stage_locally == self.stage_locally and \ +            other.codec_spec == self.codec_spec diff --git a/hail/python/hail/table.py b/hail/python/hail/table.py index 87386118c35..8e7375242c1 100644 --- a/hail/python/hail/table.py +++ b/hail/python/hail/table.py @@ -3,7 +3,7 @@  import pandas import numpy as np import pyspark -from typing import Optional, Dict, Callable +from typing import Optional, Dict, Callable, Sequence  from hail.expr.expressions import Expression, StructExpression, \     BooleanExpression, expr_struct, expr_any, expr_bool, analyze, Indices, \ @@ -1334,6 +1334,137 @@ def write(self, output: str, overwrite=False, stage_locally: bool = False,          Env.backend().execute(ir.TableWrite(self._tir, ir.TableNativeWriter(output, overwrite, stage_locally, _codec_spec)))  +    @typecheck_method(output=str, +                      fields=sequenceof(str), +                      overwrite=bool, +                      stage_locally=bool, +                      _codec_spec=nullable(str)) +    def write_many(self, +                   output: str, +                   fields: Sequence[str], +                   *, +                   overwrite: bool = False, +                   stage_locally: bool = False, +                   _codec_spec: Optional[str] = None): +        """Write fields to distinct tables. 
+ + Examples + -------- + + >>> t = hl.utils.range_table(10) + >>> t = t.annotate(a = t.idx, b = t.idx * t.idx, c = hl.str(t.idx)) + >>> t.write_many('output', fields=('a', 'b', 'c')) + >>> hl.read_table('output/a').describe() + ---------------------------------------- + Global fields: + None + ---------------------------------------- + Row fields: + 'a': int32 + 'idx': int32 + ---------------------------------------- + Key: ['idx'] + ---------------------------------------- + >>> hl.read_table('output/a').show() + +-------+-------+ + | a | idx | + +-------+-------+ + | int32 | int32 | + +-------+-------+ + | 0 | 0 | + | 1 | 1 | + | 2 | 2 | + | 3 | 3 | + | 4 | 4 | + | 5 | 5 | + | 6 | 6 | + | 7 | 7 | + | 8 | 8 | + | 9 | 9 | + +-------+-------+ + >>> hl.read_table('output/b').describe() + ---------------------------------------- + Global fields: + None + ---------------------------------------- + Row fields: + 'b': int32 + 'idx': int32 + ---------------------------------------- + Key: ['idx'] + ---------------------------------------- + >>> hl.read_table('output/b').show() + +-------+-------+ + | b | idx | + +-------+-------+ + | int32 | int32 | + +-------+-------+ + | 0 | 0 | + | 1 | 1 | + | 4 | 2 | + | 9 | 3 | + | 16 | 4 | + | 25 | 5 | + | 36 | 6 | + | 49 | 7 | + | 64 | 8 | + | 81 | 9 | + +-------+-------+ + >>> hl.read_table('output/c').describe() + ---------------------------------------- + Global fields: + None + ---------------------------------------- + Row fields: + 'c': str + 'idx': int32 + ---------------------------------------- + Key: ['idx'] + ---------------------------------------- + >>> hl.read_table('output/c').show() + +-----+-------+ + | c | idx | + +-----+-------+ + | str | int32 | + +-----+-------+ + | "0" | 0 | + | "1" | 1 | + | "2" | 2 | + | "3" | 3 | + | "4" | 4 | + | "5" | 5 | + | "6" | 6 | + | "7" | 7 | + | "8" | 8 | + | "9" | 9 | + +-----+-------+ + + .. include:: _templates/write_warning.rst + + See Also + -------- + :func:`.read_table` + + Parameters + ---------- + output : str + Path at which to write. + fields : list of str + The fields to write. + stage_locally: bool + If ``True``, major output will be written to temporary local storage + before being copied to ``output``. + overwrite : bool + If ``True``, overwrite an existing file at the destination. 
+ """ + + Env.backend().execute( + ir.TableWrite( + self._tir, + ir.TableNativeFanoutWriter(output, fields, overwrite, stage_locally, _codec_spec) + ) + ) + def _show(self, n, width, truncate, types): return Table._Show(self, n, width, truncate, types) diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py index b4dba4a48ee..35b2cf3c68e 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py @@ -10,7 +10,7 @@ import aiohttp from hailtop.utils import ( secret_alnum_string, OnlineBoundedGather2, - TransientError, retry_transient_errors_wrapper) + TransientError, retry_transient_errors) from hailtop.aiotools.fs import (FileStatus, FileListEntry, ReadableStream, WritableStream, AsyncFS, AsyncFSURL, AsyncFSFactory, FileAndDirectoryError, MultiPartCreate, UnexpectedEOFError) @@ -139,8 +139,7 @@ def _range_upper(range): assert len(split_range) == 2 return int(split_range[1]) - @retry_transient_errors_wrapper - async def _write_chunk(self): + async def _write_chunk_1(self): assert not self._done assert self._closed or self._write_buffer.size() >= self._chunk_size @@ -243,6 +242,9 @@ async def _write_chunk(self): resp.raise_for_status() assert False + async def _write_chunk(self): + await retry_transient_errors(self._write_chunk_1) + async def write(self, b): assert not self._closed assert self._write_buffer.size() < self._chunk_size @@ -269,12 +271,10 @@ def __init__(self, resp: aiohttp.ClientResponse): # https://docs.aiohttp.org/en/stable/streams.html#aiohttp.StreamReader.read # Read up to n bytes. If n is not provided, or set to -1, read until EOF # and return all read bytes. - @retry_transient_errors_wrapper async def read(self, n: int = -1) -> bytes: assert not self._closed and self._content is not None return await self._content.read(n) - @retry_transient_errors_wrapper async def readexactly(self, n: int) -> bytes: assert not self._closed and n >= 0 and self._content is not None try: diff --git a/hail/python/hailtop/utils/__init__.py b/hail/python/hailtop/utils/__init__.py index 5583549ce4e..73e4fd433a2 100644 --- a/hail/python/hailtop/utils/__init__.py +++ b/hail/python/hailtop/utils/__init__.py @@ -3,7 +3,7 @@ grouped, sync_sleep_and_backoff, sleep_and_backoff, is_transient_error, request_retry_transient_errors, request_raise_transient_errors, collect_agen, retry_all_errors, retry_transient_errors, - retry_transient_errors_with_debug_string, retry_transient_errors_wrapper, retry_long_running, run_if_changed, + retry_transient_errors_with_debug_string, retry_long_running, run_if_changed, run_if_changed_idempotent, LoggingTimer, WaitableSharedPool, RETRY_FUNCTION_SCRIPT, sync_retry_transient_errors, retry_response_returning_functions, first_extant_file, secret_alnum_string, @@ -49,7 +49,6 @@ 'retry_all_errors', 'retry_transient_errors', 'retry_transient_errors_with_debug_string', - 'retry_transient_errors_wrapper', 'retry_long_running', 'run_if_changed', 'run_if_changed_idempotent', diff --git a/hail/python/hailtop/utils/utils.py b/hail/python/hailtop/utils/utils.py index 4c8d0c3d4b7..8b059eec482 100644 --- a/hail/python/hailtop/utils/utils.py +++ b/hail/python/hailtop/utils/utils.py @@ -4,7 +4,6 @@ from types import TracebackType import concurrent import contextlib -import functools import subprocess import traceback import sys @@ -552,6 +551,12 @@ class TransientError(Exception): pass 
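For reference, the storage-client change above replaces the removed `retry_transient_errors_wrapper` decorator with an explicit call to `retry_transient_errors`, keeping a single-attempt helper separate from the retrying entry point. A minimal sketch of that pattern; `ChunkUploader` and its method bodies are hypothetical, and only `retry_transient_errors` comes from `hailtop.utils`:

```python
from hailtop.utils import retry_transient_errors


class ChunkUploader:  # hypothetical illustration, not the library's class
    async def _write_chunk_1(self):
        """Make exactly one attempt to upload the buffered chunk (placeholder)."""
        ...

    async def _write_chunk(self):
        # retry_transient_errors awaits the callable and retries with backoff
        # when the failure is classified as transient.
        return await retry_transient_errors(self._write_chunk_1)
```

This keeps retry policy at the call site; note that `read` and `readexactly` above simply drop the wrapper rather than gaining an explicit retry.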
+RETRY_ONCE_BAD_REQUEST_ERROR_MESSAGES = { + 'User project specified in the request is invalid.', + 'Invalid grant: account not found', +} + + def is_retry_once_error(e): # An exception is a "retry once error" if a rare, known bug in a dependency or in a cloud # provider can manifest as this exception *and* that manifestation is indistinguishable from a @@ -562,7 +567,7 @@ def is_retry_once_error(e): and 'azurecr.io' in e.message and 'not found: manifest unknown: ' in e.message) if isinstance(e, hailtop.httpx.ClientResponseError): - return e.status == 400 and 'User project specified in the request is invalid.' in e.body + return e.status == 400 and any(msg in e.body for msg in RETRY_ONCE_BAD_REQUEST_ERROR_MESSAGES) return False @@ -679,6 +684,8 @@ def is_transient_error(e): if isinstance(e, botocore.exceptions.ConnectionClosedError): return True if aiodocker is not None and isinstance(e, aiodocker.exceptions.DockerError): + if e.status == 500 and 'Invalid repository name' in e.message: + return False return e.status in RETRYABLE_HTTP_STATUS_CODES if isinstance(e, TransientError): return True @@ -763,15 +770,6 @@ async def retry_transient_errors_with_debug_string(debug_string: str, f: Callabl delay = await sleep_and_backoff(delay) -def retry_transient_errors_wrapper(f): - """Decorator for `retry_transient_errors`.""" - @functools.wraps(f) - async def wrapper(*args, **kwargs): - return await retry_transient_errors(f, *args, **kwargs) - - return wrapper - - def sync_retry_transient_errors(f, *args, **kwargs): delay = 0.1 errors = 0 diff --git a/hail/python/pinned-requirements.txt b/hail/python/pinned-requirements.txt index 83dd23ab2de..ef35c51b4ea 100644 --- a/hail/python/pinned-requirements.txt +++ b/hail/python/pinned-requirements.txt @@ -225,7 +225,7 @@ typing-extensions==4.2.0 # azure-core # janus # yarl -urllib3==1.26.9 +urllib3==1.26.12 # via # botocore # requests diff --git a/hail/python/test/hail/expr/test_expr.py b/hail/python/test/hail/expr/test_expr.py index a9d04541c74..f48edcbdd34 100644 --- a/hail/python/test/hail/expr/test_expr.py +++ b/hail/python/test/hail/expr/test_expr.py @@ -3616,6 +3616,9 @@ def test_bit_ops_types(self): assert hl.bit_not(1).dtype == hl.tint32 assert hl.bit_not(hl.int64(1)).dtype == hl.tint64 + assert hl.bit_count(1).dtype == hl.tint32 + assert hl.bit_count(hl.int64(1)).dtype == hl.tint32 + def test_bit_shifts(self): assert hl.eval(hl.bit_lshift(hl.int(8), 2)) == 32 assert hl.eval(hl.bit_rshift(hl.int(8), 2)) == 2 diff --git a/hail/python/test/hail/matrixtable/test_matrix_table.py b/hail/python/test/hail/matrixtable/test_matrix_table.py index f09fef9475e..0098fd0b018 100644 --- a/hail/python/test/hail/matrixtable/test_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_matrix_table.py @@ -385,6 +385,13 @@ def test_collect_cols_by_key(self): hl.Struct(row_idx=2, col_idx=1, bar=[4, 6]), hl.Struct(row_idx=2, col_idx=2, bar=[8, 10, 12])]) + def test_collect_cols_by_key_with_rand(self): + mt = hl.utils.range_matrix_table(3, 3) + mt = mt.annotate_cols(x = hl.rand_norm()) + mt = mt.collect_cols_by_key() + mt = mt.annotate_cols(x = hl.rand_norm()) + mt.cols().collect() + def test_weird_names(self): ds = self.get_mt() exprs = {'a': 5, ' a ': 5, r'\%!^!@#&#&$%#$%': [5], '$': 5, 'ß': 5} @@ -1662,6 +1669,16 @@ def test_filter_locus_position_collect_returns_data(self): assert t.filter(t.locus.position >= 1).collect() == [ hl.utils.Struct(idx=0, locus=hl.genetics.Locus(contig='2', position=1, reference_genome='GRCh37'))] + @fails_service_backend() + 
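The `_compute_type` change to the unary-op IR node earlier in this diff makes `BitCount` always produce `tint32`, which the assertions added to `test_bit_ops_types` above exercise. A small sketch of the user-visible behavior, assuming a local Hail session:

```python
import hail as hl

# bit_count now types as int32 for both 32- and 64-bit inputs, mirroring
# java.lang.Integer.bitCount and java.lang.Long.bitCount.
assert hl.bit_count(1).dtype == hl.tint32
assert hl.bit_count(hl.int64(1)).dtype == hl.tint32

assert hl.eval(hl.bit_count(0b1011)) == 3                  # three set bits
assert hl.eval(hl.bit_count(hl.int64(2 ** 40 - 1))) == 40  # forty set bits
```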
@fails_local_backend() + def test_lower_row_agg_init_arg(self): + mt = hl.balding_nichols_model(5, 200, 200) + mt2 = hl.variant_qc(mt) + mt2 = mt2.filter_rows((mt2.variant_qc.AF[0] > 0.05) & (mt2.variant_qc.AF[0] < 0.95)) + mt2 = mt2.sample_rows(.99) + rows = mt2.rows() + mt = mt.semi_join_rows(rows) + hl.hwe_normalized_pca(mt.GT) def test_keys_before_scans(): mt = hl.utils.range_matrix_table(6, 6) diff --git a/hail/python/test/hail/table/test_table.py b/hail/python/test/hail/table/test_table.py index 5c1d1526378..223c2b75e1b 100644 --- a/hail/python/test/hail/table/test_table.py +++ b/hail/python/test/hail/table/test_table.py @@ -903,12 +903,14 @@ def test_indexed_read(self): t = hl.utils.range_table(2000, 10) f = new_temp_file(extension='ht') t.write(f) + t2 = hl.read_table(f, _intervals=[ hl.Interval(start=150, end=250, includes_start=True, includes_end=False), hl.Interval(start=250, end=500, includes_start=True, includes_end=False), ]) self.assertEqual(t2.n_partitions(), 2) self.assertEqual(t2.count(), 350) + self.assertEqual(t2._force_count(), 350) self.assertTrue(t.filter((t.idx >= 150) & (t.idx < 500))._same(t2)) t2 = hl.read_table(f, _intervals=[ @@ -1827,3 +1829,34 @@ def test_to_pandas_nd_array(): df_from_python = pd.DataFrame(python_data) pd.testing.assert_frame_equal(df_from_hail, df_from_python) + + +def test_write_many(): + t = hl.utils.range_table(5) + t = t.annotate(a = t.idx, b = t.idx * t.idx, c = hl.str(t.idx)) + with hl.TemporaryDirectory(ensure_exists=False) as f: + t.write_many(f, fields=('a', 'b', 'c')) + + assert hl.read_table(f + '/a').collect() == [ + hl.Struct(idx=0, a=0), + hl.Struct(idx=1, a=1), + hl.Struct(idx=2, a=2), + hl.Struct(idx=3, a=3), + hl.Struct(idx=4, a=4) + ] + + assert hl.read_table(f + '/b').collect() == [ + hl.Struct(idx=0, b=0), + hl.Struct(idx=1, b=1), + hl.Struct(idx=2, b=4), + hl.Struct(idx=3, b=9), + hl.Struct(idx=4, b=16) + ] + + assert hl.read_table(f + '/c').collect() == [ + hl.Struct(idx=0, c='0'), + hl.Struct(idx=1, c='1'), + hl.Struct(idx=2, c='2'), + hl.Struct(idx=3, c='3'), + hl.Struct(idx=4, c='4') + ] diff --git a/hail/src/main/scala/is/hail/HailContext.scala b/hail/src/main/scala/is/hail/HailContext.scala index e2be3c864c0..94f964c61ee 100644 --- a/hail/src/main/scala/is/hail/HailContext.scala +++ b/hail/src/main/scala/is/hail/HailContext.scala @@ -150,103 +150,6 @@ object HailContext { theContext = null } - def readRowsPartition( - makeDec: (InputStream, HailClassLoader) => Decoder - )(theHailClassLoader: HailClassLoader, - r: Region, - in: InputStream, - metrics: InputMetrics = null - ): Iterator[Long] = - new Iterator[Long] { - private val region = r - - private val trackedIn = new ByteTrackingInputStream(in) - private val dec = - try { - makeDec(trackedIn, theHailClassLoader) - } catch { - case e: Exception => - in.close() - throw e - } - - private var cont: Byte = dec.readByte() - if (cont == 0) - dec.close() - - // can't throw - def hasNext: Boolean = cont != 0 - - def next(): Long = { - // !hasNext => cont == 0 => dec has been closed - if (!hasNext) - throw new NoSuchElementException("next on empty iterator") - - try { - val res = dec.readRegionValue(region) - cont = dec.readByte() - if (metrics != null) { - ExposedMetrics.incrementRecord(metrics) - ExposedMetrics.incrementBytes(metrics, trackedIn.bytesReadAndClear()) - } - - if (cont == 0) - dec.close() - - res - } catch { - case e: Exception => - dec.close() - throw e - } - } - - override def finalize(): Unit = { - dec.close() - } - } - - def readRowsIndexedPartition( - 
makeDec: (InputStream, HailClassLoader) => Decoder - )(theHailClassLoader: HailClassLoader, - ctx: RVDContext, - in: InputStream, - idxr: IndexReader, - offsetField: Option[String], - bounds: Option[Interval], - metrics: InputMetrics = null - ): Iterator[Long] = - bounds match { - case Some(b) => - new IndexReadIterator(theHailClassLoader, makeDec, ctx.r, in, idxr, offsetField.orNull, b, metrics) - case None => - idxr.close() - HailContext.readRowsPartition(makeDec)(theHailClassLoader, ctx.r, in, metrics) - } - - def readSplitRowsPartition( - theHailClassLoader: HailClassLoader, - fs: BroadcastValue[FS], - mkRowsDec: (InputStream, HailClassLoader) => Decoder, - mkEntriesDec: (InputStream, HailClassLoader) => Decoder, - mkInserter: (HailClassLoader, FS, Int, Region) => AsmFunction3RegionLongLongLong - )(ctx: RVDContext, - isRows: InputStream, - isEntries: InputStream, - idxr: Option[IndexReader], - rowsOffsetField: Option[String], - entriesOffsetField: Option[String], - bounds: Option[Interval], - partIdx: Int, - metrics: InputMetrics = null - ): Iterator[Long] = new MaybeIndexedReadZippedIterator( - is => mkRowsDec(is, theHailClassLoaderForSparkWorkers), - is => mkEntriesDec(is, theHailClassLoaderForSparkWorkers), - mkInserter(theHailClassLoader, fs.value, partIdx, ctx.partitionRegion), - ctx.r, - isRows, isEntries, - idxr.orNull, rowsOffsetField.orNull, entriesOffsetField.orNull, bounds.orNull, metrics) - def pyRemoveIrVector(id: Int) { get.irVectors.remove(id) } @@ -275,123 +178,6 @@ object HailContext { @transient override val partitioner: Option[Partitioner] = optPartitioner } } - - def readRows( - ctx: ExecuteContext, - path: String, - enc: AbstractTypedCodecSpec, - partFiles: Array[String], - requestedType: TStruct - ): (PStruct, ContextRDD[Long]) = { - val fs = ctx.fs - val (pType: PStruct, makeDec) = enc.buildDecoder(ctx, requestedType) - (pType, ContextRDD.weaken(HailContext.readPartitions(fs, path, partFiles, (_, is, m) => Iterator.single(is -> m))) - .cmapPartitions { (ctx, it) => - assert(it.hasNext) - val (is, m) = it.next - assert(!it.hasNext) - HailContext.readRowsPartition(makeDec)(theHailClassLoaderForSparkWorkers, ctx.r, is, m) - }) - } - - def readIndexedRows( - ctx: ExecuteContext, - path: String, - indexSpec: AbstractIndexSpec, - enc: AbstractTypedCodecSpec, - partFiles: Array[String], - bounds: Array[Interval], - requestedType: TStruct - ): (PStruct, ContextRDD[Long]) = { - val (pType: PStruct, makeDec) = enc.buildDecoder(ctx, requestedType) - (pType, ContextRDD.weaken(readIndexedPartitions(ctx, path, indexSpec, partFiles, Some(bounds))) - .cmapPartitions { (ctx, it) => - assert(it.hasNext) - val (is, idxr, bounds, m) = it.next - assert(!it.hasNext) - readRowsIndexedPartition(makeDec)(theHailClassLoaderForSparkWorkers, ctx, is, idxr, indexSpec.offsetField, bounds, m) - }) - } - - def readIndexedPartitions( - ctx: ExecuteContext, - path: String, - indexSpec: AbstractIndexSpec, - partFiles: Array[String], - intervalBounds: Option[Array[Interval]] = None - ): RDD[(InputStream, IndexReader, Option[Interval], InputMetrics)] = { - val idxPath = indexSpec.relPath - val fsBc = ctx.fsBc - val (keyType, annotationType) = indexSpec.types - indexSpec.offsetField.foreach { f => - require(annotationType.asInstanceOf[TStruct].hasField(f)) - require(annotationType.asInstanceOf[TStruct].fieldType(f) == TInt64) - } - val (leafPType: PStruct, leafDec) = indexSpec.leafCodec.buildDecoder(ctx, indexSpec.leafCodec.encodedVirtualType) - val (intPType: PStruct, intDec) = 
indexSpec.internalNodeCodec.buildDecoder(ctx, indexSpec.internalNodeCodec.encodedVirtualType) - val mkIndexReader = IndexReaderBuilder.withDecoders(leafDec, intDec, keyType, annotationType, leafPType, intPType) - - new IndexReadRDD(partFiles, intervalBounds, { (p, context) => - val fs = fsBc.value - val idxname = s"$path/$idxPath/${ p.file }.idx" - val filename = s"$path/parts/${ p.file }" - val idxr = mkIndexReader(theHailClassLoaderForSparkWorkers, fs, idxname, 8, SparkTaskContext.get().getRegionPool()) // default cache capacity - val in = fs.open(filename) - (in, idxr, p.bounds, context.taskMetrics().inputMetrics) - }) - } - - - def readRowsSplit( - ctx: ExecuteContext, - pathRows: String, - pathEntries: String, - indexSpecRows: Option[AbstractIndexSpec], - indexSpecEntries: Option[AbstractIndexSpec], - partFiles: Array[String], - bounds: Array[Interval], - makeRowsDec: (InputStream, HailClassLoader) => Decoder, - makeEntriesDec: (InputStream, HailClassLoader) => Decoder, - makeInserter: (HailClassLoader, FS, Int, Region) => AsmFunction3RegionLongLongLong - ): ContextRDD[Long] = { - require(!(indexSpecRows.isEmpty ^ indexSpecEntries.isEmpty)) - val fsBc = ctx.fsBc - - val mkIndexReader = indexSpecRows.map { indexSpec => - val (keyType, annotationType) = indexSpec.types - indexSpec.offsetField.foreach { f => - require(annotationType.asInstanceOf[TStruct].hasField(f)) - require(annotationType.asInstanceOf[TStruct].fieldType(f) == TInt64) - } - indexSpecEntries.get.offsetField.foreach { f => - require(annotationType.asInstanceOf[TStruct].hasField(f)) - require(annotationType.asInstanceOf[TStruct].fieldType(f) == TInt64) - } - IndexReaderBuilder.fromSpec(ctx, indexSpec) - } - - val rdd = new IndexReadRDD(partFiles, indexSpecRows.map(_ => bounds), (p, context) => { - val fs = fsBc.value - val idxr = mkIndexReader.map { mk => - val idxname = s"$pathRows/${ indexSpecRows.get.relPath }/${ p.file }.idx" - mk(theHailClassLoaderForSparkWorkers, fs, idxname, 8, SparkTaskContext.get().getRegionPool()) // default cache capacity - } - val inRows = fs.open(s"$pathRows/parts/${ p.file }") - val inEntries = fs.open(s"$pathEntries/parts/${ p.file }") - (inRows, inEntries, idxr, p.bounds, context.taskMetrics().inputMetrics) - }) - - val rowsOffsetField = indexSpecRows.flatMap(_.offsetField) - val entriesOffsetField = indexSpecEntries.flatMap(_.offsetField) - ContextRDD.weaken(rdd).cmapPartitionsWithIndex { (i, ctx, it) => - assert(it.hasNext) - val (isRows, isEntries, idxr, bounds, m) = it.next - assert(!it.hasNext) - HailContext.readSplitRowsPartition(theHailClassLoaderForSparkWorkers, fsBc, makeRowsDec, makeEntriesDec, makeInserter)( - ctx, isRows, isEntries, idxr, rowsOffsetField, entriesOffsetField, bounds, i, m) - } - - } } class HailContext private( diff --git a/hail/src/main/scala/is/hail/asm4s/ClassBuilder.scala b/hail/src/main/scala/is/hail/asm4s/ClassBuilder.scala index 1c19a4f548b..9754d1a1f71 100644 --- a/hail/src/main/scala/is/hail/asm4s/ClassBuilder.scala +++ b/hail/src/main/scala/is/hail/asm4s/ClassBuilder.scala @@ -377,7 +377,7 @@ class ClassBuilder[C]( } } - theClass.newInstance().asInstanceOf[C] + theClass.getDeclaredConstructor().newInstance().asInstanceOf[C] } } } diff --git a/hail/src/main/scala/is/hail/asm4s/Code.scala b/hail/src/main/scala/is/hail/asm4s/Code.scala index 01fdce11089..bc72880bfdb 100644 --- a/hail/src/main/scala/is/hail/asm4s/Code.scala +++ b/hail/src/main/scala/is/hail/asm4s/Code.scala @@ -930,6 +930,8 @@ class CodeInt(val lhs: Code[Int]) extends AnyVal { def toZ: 
Code[Boolean] = lhs.cne(0) def toS: Code[String] = Code.invokeStatic1[java.lang.Integer, Int, String]("toString", lhs) + + def bitCount: Code[Int] = Code.invokeStatic1[java.lang.Integer, Int, Int]("bitCount", lhs) } class CodeLong(val lhs: Code[Long]) extends AnyVal { @@ -994,6 +996,8 @@ class CodeLong(val lhs: Code[Long]) extends AnyVal { def numberOfLeadingZeros: Code[Int] = Code.invokeStatic1[java.lang.Long, Long, Int]("numberOfLeadingZeros", lhs) def numberOfTrailingZeros: Code[Int] = Code.invokeStatic1[java.lang.Long, Long, Int]("numberOfTrailingZeros", lhs) + + def bitCount: Code[Int] = Code.invokeStatic1[java.lang.Long, Long, Int]("bitCount", lhs) } class CodeFloat(val lhs: Code[Float]) extends AnyVal { diff --git a/hail/src/main/scala/is/hail/compatibility/LegacyRVDSpecs.scala b/hail/src/main/scala/is/hail/compatibility/LegacyRVDSpecs.scala index 68a4e6a2f95..d63c900b626 100644 --- a/hail/src/main/scala/is/hail/compatibility/LegacyRVDSpecs.scala +++ b/hail/src/main/scala/is/hail/compatibility/LegacyRVDSpecs.scala @@ -84,14 +84,6 @@ trait ShimRVDSpec extends AbstractRVDSpec { override def partitioner: RVDPartitioner = shim.partitioner - override def read( - ctx: ExecuteContext, - path: String, - requestedType: TStruct, - newPartitioner: Option[RVDPartitioner], - filterIntervals: Boolean - ): RVD = shim.read(ctx, path, requestedType, newPartitioner, filterIntervals) - override def typedCodecSpec: AbstractTypedCodecSpec = shim.typedCodecSpec override def partFiles: Array[String] = shim.partFiles diff --git a/hail/src/main/scala/is/hail/expr/ir/AbstractMatrixTableSpec.scala b/hail/src/main/scala/is/hail/expr/ir/AbstractMatrixTableSpec.scala index f6570c39400..7ec0c156c2e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/AbstractMatrixTableSpec.scala +++ b/hail/src/main/scala/is/hail/expr/ir/AbstractMatrixTableSpec.scala @@ -113,24 +113,6 @@ case class RVDComponentSpec(rel_path: String) extends ComponentSpec { } def indexed(fs: FS, path: String): Boolean = rvdSpec(fs, path).indexed - - def read( - ctx: ExecuteContext, - path: String, - requestedType: TStruct, - newPartitioner: Option[RVDPartitioner] = None, - filterIntervals: Boolean = false - ): RVD = { - val rvdPath = path + "/" + rel_path - rvdSpec(ctx.fs, path) - .read(ctx, rvdPath, requestedType, newPartitioner, filterIntervals) - } - - def readLocalSingleRow(ctx: ExecuteContext, path: String, requestedType: TStruct): (PStruct, Long) = { - val rvdPath = path + "/" + rel_path - rvdSpec(ctx.fs, path) - .readLocalSingleRow(ctx, rvdPath, requestedType) - } } case class PartitionCountsComponentSpec(counts: Seq[Long]) extends ComponentSpec diff --git a/hail/src/main/scala/is/hail/expr/ir/BinarySearch.scala b/hail/src/main/scala/is/hail/expr/ir/BinarySearch.scala index f3b0c7f5b6e..ea65427bce2 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BinarySearch.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BinarySearch.scala @@ -279,7 +279,12 @@ object BinarySearch { runSearchBounded[T](cb, haystack, compare, 0, haystack.loadLength(), found, notFound) } -class BinarySearch[C](mb: EmitMethodBuilder[C], containerType: SContainer, eltType: EmitType, keyOnly: Boolean) { +class BinarySearch[C](mb: EmitMethodBuilder[C], + containerType: SContainer, + eltType: EmitType, + getKey: (EmitCodeBuilder, EmitValue) => EmitValue, + bound: String = "lower", + ltF: CodeOrdering.F[Boolean] = null) { val containerElementType: EmitType = containerType.elementEmitType val findElt = mb.genEmitMethod("findElt", FastIndexedSeq[ParamType](containerType.paramType, 
eltType.paramType), typeInfo[Int]) @@ -289,35 +294,27 @@ class BinarySearch[C](mb: EmitMethodBuilder[C], containerType: SContainer, eltTy val haystack = findElt.getSCodeParam(1).asIndexable val needle = findElt.getEmitParam(cb, 2, null) // no streams - def ltNeedle(x: IEmitCode): Code[Boolean] = if (keyOnly) { - val kt: EmitType = containerElementType.st match { - case s: SBaseStruct => - require(s.size == 2) - s.fieldEmitTypes(0) - case interval: SInterval => - interval.pointEmitType - } - - val keyLT = mb.ecb.getOrderingFunction(kt.st, eltType.st, CodeOrdering.Lt()) - - val key = cb.memoize(x.flatMap(cb) { - case x: SBaseStructValue => - x.loadField(cb, 0) - case x: SIntervalValue => - x.loadStart(cb) - }) - - keyLT(cb, key, needle) - } else { - val lt = mb.ecb.getOrderingFunction(containerElementType.st, eltType.st, CodeOrdering.Lt()) - lt(cb, cb.memoize(x), needle) + val f: ( + EmitCodeBuilder, + SIndexableValue, + IEmitCode => Code[Boolean], + Value[Int], + Value[Int] + ) => Value[Int] = bound match { + case "upper" => BinarySearch.upperBound + case "lower" => BinarySearch.lowerBound } - BinarySearch.lowerBound(cb, haystack, ltNeedle) + f(cb, haystack, { containerElement => + val elementVal = cb.memoize(containerElement, "binary_search_elt") + val compareVal = getKey(cb, elementVal) + val lt = Option(ltF).getOrElse(mb.ecb.getOrderingFunction(compareVal.st, eltType.st, CodeOrdering.Lt())) + lt(cb, compareVal, needle) + }, 0, haystack.loadLength()) } // check missingness of v before calling - def lowerBound(cb: EmitCodeBuilder, array: SValue, v: EmitCode): Value[Int] = { + def search(cb: EmitCodeBuilder, array: SValue, v: EmitCode): Value[Int] = { cb.memoize(cb.invokeCode[Int](findElt, array, v)) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index 558aa9f7fc0..92eac9c3848 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -1104,8 +1104,18 @@ class Emit[C]( case x@LowerBoundOnOrderedCollection(orderedCollection, elem, onKey) => emitI(orderedCollection).map(cb) { a => val e = EmitCode.fromI(cb.emb)(cb => this.emitI(elem, cb, region, env, container, loopEnv)) - val bs = new BinarySearch[C](mb, a.st.asInstanceOf[SContainer], e.emitType, keyOnly = onKey) - primitive(bs.lowerBound(cb, a, e)) + val bs = new BinarySearch[C](mb, a.st.asInstanceOf[SContainer], e.emitType, { (cb, elt) => + + if (onKey) { + cb.memoize(elt.toI(cb).flatMap(cb) { + case x: SBaseStructValue => + x.loadField(cb, 0) + case x: SIntervalValue => + x.loadStart(cb) + }) + } else elt + }) + primitive(bs.search(cb, a, e)) } case x@ArraySort(a, left, right, lessThan) => diff --git a/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala b/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala index 1a5538e1c4e..cc84df14155 100644 --- a/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala +++ b/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala @@ -719,7 +719,7 @@ class EmitClassBuilder[C]( } } } - val f = theClass.newInstance().asInstanceOf[C] + val f = theClass.getDeclaredConstructor().newInstance().asInstanceOf[C] f.asInstanceOf[FunctionWithHailClassLoader].addHailClassLoader(hcl) f.asInstanceOf[FunctionWithFS].addFS(fs) f.asInstanceOf[FunctionWithPartitionRegion].addPartitionRegion(region) diff --git a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala index c29800ee7c1..db5548fbfdf 100644 --- 
a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala @@ -23,7 +23,7 @@ object Interpret { apply(tir, ctx, optimize = true) def apply(tir: TableIR, ctx: ExecuteContext, optimize: Boolean): TableValue = { - val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR] + val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR].noSharing lowered.analyzeAndExecute(ctx).asTableValue(ctx) } @@ -209,10 +209,17 @@ object Interpret { case TFloat64 => -xValue.asInstanceOf[Double] } case BitNot() => + assert(x.typ.isInstanceOf[TIntegral]) x.typ match { case TInt32 => ~xValue.asInstanceOf[Int] case TInt64 => ~xValue.asInstanceOf[Long] } + case BitCount() => + assert(x.typ.isInstanceOf[TIntegral]) + x.typ match { + case TInt32 => Integer.bitCount(xValue.asInstanceOf[Int]) + case TInt64 => java.lang.Long.bitCount(xValue.asInstanceOf[Long]) + } } case ApplyComparisonOp(op, l, r) => val lValue = interpret(l, env, args) diff --git a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala index ddda5d12698..95d1e3b289f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala @@ -4,7 +4,7 @@ import is.hail.backend.ExecuteContext import is.hail.expr.ir.functions.{WrappedMatrixToTableFunction, WrappedMatrixToValueFunction} import is.hail.expr.ir._ import is.hail.types._ -import is.hail.types.virtual.{TArray, TBaseStruct, TDict, TInt32, TInterval, TStruct} +import is.hail.types.virtual.{TArray, TBaseStruct, TDict, TInt32, TInterval, TString, TStruct} import is.hail.utils._ object LowerMatrixIR { @@ -102,15 +102,28 @@ object LowerMatrixIR { case CastTableToMatrix(child, entries, cols, colKey) => val lc = lower(ctx, child, ab) - lc.mapRows( - irIf('row (Symbol(entries)).isNA) { - irDie("missing entry array unsupported in 'to_matrix_table_row_major'", lc.typ.rowType) - } { - irIf('row (Symbol(entries)).len.cne('global (Symbol(cols)).len)) { - irDie("length mismatch between entry array and column array in 'to_matrix_table_row_major'", lc.typ.rowType) - } { - 'row - } + val row = Ref("row", lc.typ.rowType) + val glob = Ref("global", lc.typ.globalType) + TableMapRows( + lc, + bindIR(GetField(row, entries)) { entries => + If(IsNA(entries), + Die("missing entry array unsupported in 'to_matrix_table_row_major'", row.typ), + bindIRs(ArrayLen(entries), ArrayLen(GetField(glob, cols))) { case Seq(entriesLen, colsLen) => + If(entriesLen cne colsLen, + Die( + strConcat( + Str("length mismatch between entry array and column array in 'to_matrix_table_row_major': "), + invoke("str", TString, entriesLen), + Str(" entries, "), + invoke("str", TString, colsLen), + Str(" cols, at "), + invoke("str", TString, SelectFields(row, child.typ.key)) + ), row.typ, -1), + row + ) + } + ) } ).rename(Map(entries -> entriesFieldName), Map(cols -> colsFieldName)) diff --git a/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala b/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala index 4660d35d0f0..555af680d5c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala @@ -69,6 +69,6 @@ object LowerOrInterpretNonCompilable { } } - rewrite(ir, mutable.HashMap.empty) + rewrite(ir.noSharing, mutable.HashMap.empty) } } diff --git 
a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index b1154a08248..d79b9acaf76 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -12,7 +12,7 @@ import is.hail.expr.ir.streams.{StreamArgType, StreamProducer} import is.hail.io._ import is.hail.io.avro.AvroTableReader import is.hail.io.fs.FS -import is.hail.io.index.{IndexReadIterator, IndexReader, IndexReaderBuilder, LeafChild} +import is.hail.io.index.{IndexReader, IndexReaderBuilder, LeafChild, StagedIndexReader} import is.hail.linalg.{BlockMatrix, BlockMatrixMetadata, BlockMatrixReadRowBlockedRDD} import is.hail.rvd._ import is.hail.sparkextras.ContextRDD @@ -628,85 +628,74 @@ case class PartitionNativeReaderIndexed(spec: AbstractTypedCodecSpec, indexSpec: val mb = cb.emb - val (eltType, makeDec) = spec.buildDecoder(ctx, requestedType) - - val (keyType, annotationType) = indexSpec.types - val (leafPType: PStruct, leafDec) = indexSpec.leafCodec.buildDecoder(ctx, indexSpec.leafCodec.encodedVirtualType) - val (intPType: PStruct, intDec) = indexSpec.internalNodeCodec.buildDecoder(ctx, indexSpec.internalNodeCodec.encodedVirtualType) - val mkIndexReader = IndexReaderBuilder.withDecoders(leafDec, intDec, keyType, annotationType, leafPType, intPType) - - val makeIndexCode = mb.getObject[Function5[HailClassLoader, FS, String, Int, RegionPool, IndexReader]](mkIndexReader) - val makeDecCode = mb.getObject[(InputStream, HailClassLoader) => Decoder](makeDec) + val index = new StagedIndexReader(cb.emb, indexSpec) context.toI(cb).map(cb) { case ctxStruct: SBaseStructValue => - val getIndexReader: Code[String] => Code[IndexReader] = { (indexPath: Code[String]) => - Code.checkcast[IndexReader]( - makeIndexCode.invoke[AnyRef, AnyRef, AnyRef, AnyRef, AnyRef, AnyRef]( - "apply", mb.getHailClassLoader, mb.getFS, indexPath, Code.boxInt(8), mb.ecb.pool())) - } - - val next = mb.newLocal[Long]("pnr_next") - val idxr = mb.genFieldThisRef[IndexReader]("pnri_idx_reader") - val it = mb.genFieldThisRef[IndexReadIterator]("pnri_idx_iterator") + val nToRead = mb.genFieldThisRef[Long]("n_to_read") + val ib = mb.genFieldThisRef[InputBuffer]("buffer") val region = mb.genFieldThisRef[Region]("pnr_region") + val decodedRow = cb.emb.newPField("rowsValue", spec.encodedType.decodedSType(requestedType)) + val producer = new StreamProducer { - override val length: Option[EmitCodeBuilder => Code[Int]] = None + override val length: Option[EmitCodeBuilder => Code[Int]] = Some(_ => nToRead.toI) override def initialize(cb: EmitCodeBuilder): Unit = { - cb.assign(idxr, getIndexReader(ctxStruct + val indexPath = ctxStruct .loadField(cb, "indexPath") .get(cb) .asString - .loadString(cb))) - cb.assign(it, - Code.newInstance8[IndexReadIterator, - HailClassLoader, - (InputStream, HailClassLoader) => Decoder, - Region, - InputStream, - IndexReader, - String, - Interval, - InputMetrics]( - mb.ecb.getHailClassLoader, - makeDecCode, - region, - mb.open(ctxStruct.loadField(cb, "partitionPath") - .get(cb) - .asString - .loadString(cb), true), - idxr, - Code._null[String], - ctxStruct.loadField(cb, "interval") - .consumeCode[Interval](cb, - cb.memoize(Code._fatal[Interval]("")), - { pc => - val pt = PType.canonical(pc.st.storageType()).asInstanceOf[PInterval] - val copied = pc.copyToRegion(cb, region, SIntervalPointer(pt)).asInterval - val javaInterval = coerce[Interval](StringFunctions.svalueToJavaValue(cb, region, copied)) - cb.memoize(Code.invokeScalaObject1[AnyRef, 
Interval]( - RVDPartitioner.getClass, - "irRepresentationToInterval", - javaInterval)) - } - ), - Code._null[InputMetrics] - )) + .loadString(cb) + val partitionPath = ctxStruct + .loadField(cb, "partitionPath") + .get(cb) + .asString + .loadString(cb) + val interval = ctxStruct + .loadField(cb, "interval") + .get(cb) + .asInterval + index.initialize(cb, indexPath) + + val indexResult = index.queryInterval(cb, partitionRegion, interval) + val n = indexResult.loadField(cb, 0) + .get(cb) + .asInt64 + .value + cb.assign(nToRead, n) + + cb.assign(ib, spec.buildCodeInputBuffer(Code.newInstance[ByteTrackingInputStream, InputStream](cb.emb.open(partitionPath, false)))) + cb.ifx(n > 0, { + val firstOffset = indexResult.loadField(cb, 1) + .get(cb) + .asBaseStruct + .loadField(cb, "offset") + .get(cb) + .asInt64 + .value + + cb += ib.seek(firstOffset) + }) + index.close(cb) } override val elementRegion: Settable[Region] = region override val requiresMemoryManagementPerElement: Boolean = true override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - cb.ifx(!it.invoke[Boolean]("hasNext"), cb.goto(LendOfStream)) - cb.assign(next, it.invoke[Long]("_next")) + cb.ifx(nToRead <= 0L, cb.goto(LendOfStream)) + val next = ib.readByte() + cb.ifx(next cne 1, cb._fatal(s"bad buffer state!")) + cb.assign(nToRead, nToRead - 1L) + cb.assign(decodedRow, spec.encodedType.buildDecoder(requestedType, cb.emb.ecb)(cb, region, ib)) cb.goto(LproduceElementDone) } - override val element: EmitCode = EmitCode.fromI(mb)(cb => IEmitCode.present(cb, eltType.loadCheapSCode(cb, next))) + override val element: EmitCode = EmitCode.fromI(mb) { cb => + IEmitCode.present(cb, decodedRow) + } - override def close(cb: EmitCodeBuilder): Unit = cb += it.invoke[Unit]("close") + override def close(cb: EmitCodeBuilder): Unit = cb += ib.close() } SStreamValue(producer) } @@ -811,7 +800,6 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionRe } } - case class PartitionZippedIndexedNativeReader(specLeft: AbstractTypedCodecSpec, specRight: AbstractTypedCodecSpec, indexSpecLeft: AbstractIndexSpec, indexSpecRight: AbstractIndexSpec, key: IndexedSeq[String]) extends PartitionReader { @@ -861,114 +849,118 @@ case class PartitionZippedIndexedNativeReader(specLeft: AbstractTypedCodecSpec, val (leftRType, rightRType) = splitRequestedTypes(requestedType) - val makeIndexCode = { - val (keyType, annotationType) = indexSpecLeft.types - val (leafPType: PStruct, leafDec) = indexSpecLeft.leafCodec.buildDecoder(ctx, indexSpecLeft.leafCodec.encodedVirtualType) - val (intPType: PStruct, intDec) = indexSpecLeft.internalNodeCodec.buildDecoder(ctx, indexSpecLeft.internalNodeCodec.encodedVirtualType) - val mkIndexReader = IndexReaderBuilder.withDecoders(leafDec, intDec, keyType, annotationType, leafPType, intPType) + val rowsDec = specLeft.encodedType.buildDecoder(leftRType, cb.emb.ecb) + val entriesDec = specRight.encodedType.buildDecoder(rightRType, cb.emb.ecb) - mb.getObject[Function5[HailClassLoader, FS, String, Int, RegionPool, IndexReader]](mkIndexReader) - } + val rowsValue = cb.emb.newPField("rowsValue", specLeft.encodedType.decodedSType(leftRType)) + val entriesValue = cb.emb.newPField("entriesValue", specRight.encodedType.decodedSType(rightRType)) val leftOffsetFieldIndex = indexSpecLeft.offsetFieldIndex val rightOffsetFieldIndex = indexSpecRight.offsetFieldIndex - context.toI(cb).map(cb) { case ctxStruct: SBaseStructValue => + val rowsBuffer = cb.emb.genFieldThisRef[InputBuffer]("rows_inputbuffer") + val 
entriesBuffer = cb.emb.genFieldThisRef[InputBuffer]("entries_inputbuffer") - def getIndexReader(cb: EmitCodeBuilder, ctxMemo: SBaseStructValue): Code[IndexReader] = { - val indexPath = ctxMemo - .loadField(cb, "indexPath") - .handle(cb, cb._fatal("")) - .asString - .loadString(cb) - Code.checkcast[IndexReader]( - makeIndexCode.invoke[AnyRef, AnyRef, AnyRef, AnyRef, AnyRef, AnyRef]( - "apply", mb.getHailClassLoader, mb.getFS, indexPath, Code.boxInt(8), cb.emb.ecb.pool())) - } + val index = new StagedIndexReader(cb.emb, indexSpecLeft) - def getInterval(cb: EmitCodeBuilder, region: Value[Region], ctxMemo: SBaseStructValue): Code[Interval] = { - Code.invokeScalaObject1[AnyRef, Interval]( - RVDPartitioner.getClass, - "irRepresentationToInterval", - StringFunctions.svalueToJavaValue(cb, region, ctxMemo.loadField(cb, "interval").get(cb))) - } - - val indexReader = cb.emb.genFieldThisRef[IndexReader]("idx_reader") - val idx = cb.emb.genFieldThisRef[BufferedIterator[LeafChild]]("idx") - val rowsDec = specLeft.encodedType.buildDecoder(leftRType, cb.emb.ecb) - val entriesDec = specRight.encodedType.buildDecoder(rightRType, cb.emb.ecb) - - val rowsValue = cb.emb.newPField("rowsValue", specLeft.encodedType.decodedSType(leftRType)) - val entriesValue = cb.emb.newPField("entriesValue", specRight.encodedType.decodedSType(rightRType)) + context.toI(cb).map(cb) { case ctxStruct: SBaseStructValue => - val rowsBuffer = cb.emb.genFieldThisRef[InputBuffer]("rows_inputbuffer") - val entriesBuffer = cb.emb.genFieldThisRef[InputBuffer]("entries_inputbuffer") + val nToRead = mb.genFieldThisRef[Long]("n_to_read") - val region = cb.emb.genFieldThisRef[Region]("zipped_indexed_reader_region") + val region = mb.genFieldThisRef[Region]("pnr_region") val producer = new StreamProducer { - override val length: Option[EmitCodeBuilder => Code[Int]] = None + override val length: Option[EmitCodeBuilder => Code[Int]] = Some(_ => nToRead.toI) override def initialize(cb: EmitCodeBuilder): Unit = { + val indexPath = ctxStruct + .loadField(cb, "indexPath") + .get(cb) + .asString + .loadString(cb) + val interval = ctxStruct + .loadField(cb, "interval") + .get(cb) + .asInterval + index.initialize(cb, indexPath) + + val indexResult = index.queryInterval(cb, partitionRegion, interval) + val n = indexResult.loadField(cb, 0) + .get(cb) + .asInt64 + .value + cb.assign(nToRead, n) + cb.assign(rowsBuffer, specLeft.buildCodeInputBuffer( Code.newInstance[ByteTrackingInputStream, InputStream]( mb.open(ctxStruct.loadField(cb, "leftPartitionPath") - .handle(cb, cb._fatal("")) + .get(cb) .asString .loadString(cb), true)))) cb.assign(entriesBuffer, specRight.buildCodeInputBuffer( Code.newInstance[ByteTrackingInputStream, InputStream]( mb.open(ctxStruct.loadField(cb, "rightPartitionPath") - .handle(cb, cb._fatal("")) + .get(cb) .asString .loadString(cb), true)))) - cb.assign(indexReader, getIndexReader(cb, ctxStruct)) - cb.assign(idx, - indexReader - .invoke[Interval, Iterator[LeafChild]]("queryByInterval", getInterval(cb, partitionRegion, ctxStruct)) - .invoke[BufferedIterator[LeafChild]]("buffered")) - - cb.ifx(idx.invoke[Boolean]("hasNext"), { - val lcHead = cb.newLocal[LeafChild]("lcHead", idx.invoke[LeafChild]("head")) - - leftOffsetFieldIndex match { - case Some(rowOffsetIdx) => - cb += rowsBuffer.invoke[Long, Unit]("seek", lcHead.invoke[Int, Long]("longChild", const(rowOffsetIdx))) + cb.ifx(n > 0, { + val leafNode = indexResult.loadField(cb, 1) + .get(cb) + .asBaseStruct + + val leftSeekAddr = leftOffsetFieldIndex match { + case 
Some(offsetIdx) => + leafNode + .loadField(cb, "annotation") + .get(cb) + .asBaseStruct + .loadField(cb, offsetIdx) + .get(cb) case None => - cb += rowsBuffer.invoke[Long, Unit]("seek", lcHead.invoke[Long]("recordOffset")) + leafNode + .loadField(cb, "offset") + .get(cb) } - - rightOffsetFieldIndex match { - case Some(rowOffsetIdx) => - cb += entriesBuffer.invoke[Long, Unit]("seek", lcHead.invoke[Int, Long]("longChild", const(rowOffsetIdx))) + cb += rowsBuffer.seek(leftSeekAddr.asInt64.value) + + val rightSeekAddr = rightOffsetFieldIndex match { + case Some(offsetIdx) => + leafNode + .loadField(cb, "annotation") + .get(cb) + .asBaseStruct + .loadField(cb, offsetIdx) + .get(cb) case None => - cb += entriesBuffer.invoke[Long, Unit]("seek", lcHead.invoke[Long]("recordOffset")) + leafNode + .loadField(cb, "offset") + .get(cb) } + cb += entriesBuffer.seek(rightSeekAddr.asInt64.value) }) + + index.close(cb) } override val elementRegion: Settable[Region] = region override val requiresMemoryManagementPerElement: Boolean = true override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - val cr = cb.newLocal[Int]("cr", rowsBuffer.invoke[Byte]("readByte").toI) - val ce = cb.newLocal[Int]("ce", entriesBuffer.invoke[Byte]("readByte").toI) - cb.ifx(ce.cne(cr), cb._fatal(s"mismatch between streams in zipped indexed reader")) - - cb.ifx(cr.ceq(0) || !idx.invoke[Boolean]("hasNext"), cb.goto(LendOfStream)) - - cb += Code.toUnit(idx.invoke[LeafChild]("next")) - + cb.ifx(nToRead <= 0L, cb.goto(LendOfStream)) + val nextRow = rowsBuffer.readByte() + cb.ifx(nextRow cne 1, cb._fatal(s"bad rows buffer state!")) + val nextEntries = entriesBuffer.readByte() + cb.ifx(nextEntries cne 1, cb._fatal(s"bad entries buffer state!")) + cb.assign(nToRead, nToRead - 1L) cb.assign(rowsValue, rowsDec(cb, region, rowsBuffer)) cb.assign(entriesValue, entriesDec(cb, region, entriesBuffer)) - cb.goto(LproduceElementDone) } + override val element: EmitCode = EmitCode.fromI(mb)(cb => IEmitCode.present(cb, SInsertFieldsStruct.merge(cb, rowsValue.asBaseStruct, entriesValue.asBaseStruct))) override def close(cb: EmitCodeBuilder): Unit = { - indexReader.invoke[Unit]("close") rowsBuffer.invoke[Unit]("close") entriesBuffer.invoke[Unit]("close") } @@ -1006,21 +998,7 @@ class TableNativeReader( } def apply(tr: TableRead, ctx: ExecuteContext): TableValue = { - val (globalType, globalsOffset) = spec.globalsComponent.readLocalSingleRow(ctx, params.path, tr.typ.globalType) - val rvd = if (tr.dropRows) { - RVD.empty(tr.typ.canonicalRVDType) - } else { - val partitioner = if (filterIntervals) - params.options.map(opts => RVDPartitioner.union(tr.typ.keyType, opts.intervals, tr.typ.key.length - 1)) - else - params.options.map(opts => new RVDPartitioner(tr.typ.keyType, opts.intervals)) - val rvd = spec.rowsComponent.read(ctx, params.path, tr.typ.rowType, partitioner, filterIntervals) - if (!rvd.typ.key.startsWith(tr.typ.key)) - fatal(s"Error while reading table ${params.path}: legacy table written without key." 
+ - s"\n Read and write with version 0.2.70 or earlier") - rvd - } - TableValue(ctx, tr.typ, BroadcastRow(ctx, RegionValue(ctx.r, globalsOffset), globalType.setRequired(true).asInstanceOf[PStruct]), rvd) + TableExecuteIntermediate(lower(ctx, tr.typ)).asTableValue(ctx) } override def toJValue: JValue = { @@ -1116,40 +1094,7 @@ case class TableNativeZippedReader( } def apply(tr: TableRead, ctx: ExecuteContext): TableValue = { - val fs = ctx.fs - val (globalPType: PStruct, globalsOffset) = specLeft.globalsComponent.readLocalSingleRow(ctx, pathLeft, tr.typ.globalType) - val rvd = if (tr.dropRows) { - RVD.empty(tr.typ.canonicalRVDType) - } else { - val partitioner = if (filterIntervals) - intervals.map(i => RVDPartitioner.union(tr.typ.keyType, i, tr.typ.key.length - 1)) - else - intervals.map(i => new RVDPartitioner(tr.typ.keyType, i)) - if (tr.typ.rowType.fieldNames.forall(f => !rightFieldSet.contains(f))) { - specLeft.rowsComponent.read(ctx, pathLeft, tr.typ.rowType, partitioner, filterIntervals) - } else if (tr.typ.rowType.fieldNames.forall(f => !leftFieldSet.contains(f))) { - specRight.rowsComponent.read(ctx, pathRight, tr.typ.rowType, partitioner, filterIntervals) - } else { - val rvdSpecLeft = specLeft.rowsComponent.rvdSpec(fs, pathLeft) - val rvdSpecRight = specRight.rowsComponent.rvdSpec(fs, pathRight) - val rvdPathLeft = specLeft.rowsComponent.absolutePath(pathLeft) - val rvdPathRight = specRight.rowsComponent.absolutePath(pathRight) - - val leftRType = tr.typ.rowType.filter(f => leftFieldSet.contains(f.name))._1 - val rightRType = tr.typ.rowType.filter(f => rightFieldSet.contains(f.name))._1 - - AbstractRVDSpec.readZipped(ctx, - rvdSpecLeft, rvdSpecRight, - rvdPathLeft, rvdPathRight, - partitioner, filterIntervals, - tr.typ.rowType, - leftRType, rightRType, - tr.typ.key, - fieldInserter) - } - } - - TableValue(ctx, tr.typ, BroadcastRow(ctx, RegionValue(ctx.r, globalsOffset), globalPType.setRequired(true).asInstanceOf[PStruct]), rvd) + TableExecuteIntermediate(lower(ctx, tr.typ)).asTableValue(ctx) } override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = { diff --git a/hail/src/main/scala/is/hail/expr/ir/TableWriter.scala b/hail/src/main/scala/is/hail/expr/ir/TableWriter.scala index 410d271bd30..19b68a30647 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableWriter.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableWriter.scala @@ -17,12 +17,12 @@ import is.hail.rvd.{AbstractRVDSpec, IndexSpec, RVDPartitioner, RVDSpecMaker} import is.hail.types.encoded.EType import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SContainer, SStringValue, SVoidValue} import is.hail.types.physical._ -import is.hail.types.physical.stypes.{EmitType, SCode} -import is.hail.types.physical.stypes.concrete.{SStackStruct, SJavaArrayString, SJavaArrayStringValue} +import is.hail.types.physical.stypes.{EmitType, SCode, SValue, SSettable} +import is.hail.types.physical.stypes.concrete.{SStackStruct, SJavaArrayString, SJavaArrayStringValue, SSubsetStruct, SSubsetStructValue} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.{SBooleanValue, SInt64Value, SInt64} import is.hail.types.virtual._ -import is.hail.types.{RIterable, RStruct, RTable, TableType, TypeWithRequiredness} +import is.hail.types.{RIterable, RStruct, RTuple, RTable, TableType, TypeWithRequiredness} import is.hail.utils._ import is.hail.utils.richUtils.ByteTrackingOutputStream import is.hail.variant.ReferenceGenome @@ -32,7 +32,7 @@ import 
org.json4s.{DefaultFormats, Formats, JBool, JObject, ShortTypeHints} object TableWriter { implicit val formats: Formats = new DefaultFormats() { override val typeHints = ShortTypeHints( - List(classOf[TableNativeWriter], classOf[TableTextWriter]), typeHintFieldName = "name") + List(classOf[TableNativeFanoutWriter], classOf[TableNativeWriter], classOf[TableTextWriter]), typeHintFieldName = "name") } } @@ -169,6 +169,23 @@ case class TableNativeWriter( } } +object PartitionNativeWriter { + val ctxType = TString + def fullReturnType(keyType: TStruct): TStruct = TStruct( + "filePath" -> TString, + "partitionCounts" -> TInt64, + "distinctlyKeyed" -> TBoolean, + "firstKey" -> keyType, + "lastKey" -> keyType, + "partitionByteSize" -> TInt64 + ) + + def returnType(keyType: TStruct, trackTotalBytes: Boolean): TStruct = { + val t = PartitionNativeWriter.fullReturnType(keyType) + if (trackTotalBytes) t else t.filterSet(Set("partitionByteSize"), include=false)._1 + } +} + case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, keyFields: IndexedSeq[String], partPrefix: String, index: Option[(String, PStruct)] = None, localDir: Option[String] = None, trackTotalBytes: Boolean = false) extends PartitionWriter { @@ -179,13 +196,9 @@ case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, keyFields: Indexe val keyType = spec.encodedVirtualType.asInstanceOf[TStruct].select(keyFields)._1 - def ctxType: Type = TString - val returnType: Type = { - val types = Seq("filePath" -> TString, "partitionCounts" -> TInt64, "distinctlyKeyed" -> TBoolean, - "firstKey" -> keyType, "lastKey" -> keyType - ) ++ Some("partitionByteSize" -> TInt64).filter(_ => trackTotalBytes) - TStruct(types: _*) - } + def ctxType = PartitionNativeWriter.ctxType + val returnType = PartitionNativeWriter.returnType(keyType, trackTotalBytes) + def unionTypeRequiredness(r: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = { val rs = r.asInstanceOf[RStruct] val rKeyType = streamType.elementType.asInstanceOf[RStruct].select(keyFields.toArray) @@ -201,84 +214,31 @@ case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, keyFields: Indexe throw new LowererUnsupportedOperation("stageLocally option not yet implemented") def ifIndexed[T >: Null](obj: => T): T = if (hasIndex) obj else null - def consumeStream( - ctx: ExecuteContext, - cb: EmitCodeBuilder, - stream: StreamProducer, - context: EmitCode, - region: Value[Region]): IEmitCode = { - - val mb = cb.emb - - val indexKeyType = ifIndexed { index.get._2 } - val indexWriter = ifIndexed { StagedIndexWriter.withDefaults(indexKeyType, mb.ecb) } - - context.toI(cb).map(cb) { case ctx: SStringValue => - val filename = mb.newLocal[String]("filename") - val os = mb.newLocal[ByteTrackingOutputStream]("write_os") - val ob = mb.newLocal[OutputBuffer]("write_ob") - val n = mb.newLocal[Long]("partition_count") - val byteCount = if (trackTotalBytes) Some(mb.newPLocal("partition_byte_count", SInt64)) else None - val distinctlyKeyed = mb.newLocal[Boolean]("distinctlyKeyed") + class StreamConsumer( + _ctx: SValue, + private[this] val cb: EmitCodeBuilder, + private[this] val region: Value[Region] + ) { + private[this] val ctx = _ctx.asString + private[this] val mb = cb.emb + private[this] val indexKeyType = ifIndexed { index.get._2 } + private[this] val indexWriter = ifIndexed { StagedIndexWriter.withDefaults(indexKeyType, mb.ecb) } + private[this] val filename = mb.newLocal[String]("filename") + private[this] val os = 
mb.newLocal[ByteTrackingOutputStream]("write_os") + private[this] val ob = mb.newLocal[OutputBuffer]("write_ob") + private[this] val n = mb.newLocal[Long]("partition_count") + private[this] val byteCount = if (trackTotalBytes) Some(mb.newPLocal("partition_byte_count", SInt64)) else None + private[this] val distinctlyKeyed = mb.newLocal[Boolean]("distinctlyKeyed") + private[this] val keyEmitType = EmitType(spec.decodedPType(keyType).sType, false) + private[this] val firstSeenSettable = mb.newEmitLocal("pnw_firstSeen", keyEmitType) + private[this] val lastSeenSettable = mb.newEmitLocal("pnw_lastSeen", keyEmitType) + + def setup(): Unit = { cb.assign(distinctlyKeyed, !keyFields.isEmpty) // True until proven otherwise, if there's a key to care about at all. - - val keyEmitType = EmitType(spec.decodedPType(keyType).sType, false) - - val firstSeenSettable = mb.newEmitLocal("pnw_firstSeen", keyEmitType) - val lastSeenSettable = mb.newEmitLocal("pnw_lastSeen", keyEmitType) // Start off missing, we will use this to determine if we haven't processed any rows yet. cb.assign(firstSeenSettable, EmitCode.missing(cb.emb, keyEmitType.st)) cb.assign(lastSeenSettable, EmitCode.missing(cb.emb, keyEmitType.st)) - - def writeFile(cb: EmitCodeBuilder, codeRow: EmitCode): Unit = { - val row = codeRow.toI(cb).get(cb, "row can't be missing").asBaseStruct - - if (hasIndex) { - indexWriter.add(cb, { - val indexKeyPType = index.get._2 - IEmitCode.present(cb, indexKeyPType.asInstanceOf[PCanonicalBaseStruct] - .constructFromFields(cb, stream.elementRegion, - indexKeyPType.fields.map{ f => - EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name)) - }, - deepCopy = true)) - }, - ob.invoke[Long]("indexOffset"), - IEmitCode.present(cb, PCanonicalStruct().loadCheapSCode(cb, 0L))) - } - - val key = SStackStruct.constructFromArgs(cb, stream.elementRegion, keyType, keyType.fields.map { f => - EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name)) - }:_*) - - if (!keyFields.isEmpty) { - cb.ifx(distinctlyKeyed, { - lastSeenSettable.loadI(cb).consume(cb, { - // If there's no last seen, we are in the first row. 
- cb.assign(firstSeenSettable, EmitValue.present(key.copyToRegion(cb, region, firstSeenSettable.st))) - }, { lastSeen => - val comparator = EQ(lastSeenSettable.emitType.virtualType).codeOrdering(cb.emb.ecb, lastSeenSettable.st, key.st) - val equalToLast = comparator(cb, lastSeenSettable, EmitValue.present(key)) - cb.ifx(equalToLast.asInstanceOf[Value[Boolean]], { - cb.assign(distinctlyKeyed, false) - }) - }) - }) - cb.assign(lastSeenSettable, IEmitCode.present(cb, key.copyToRegion(cb, region, lastSeenSettable.st))) - } - - cb += ob.writeByte(1.asInstanceOf[Byte]) - - spec.encodedType.buildEncoder(row.st, cb.emb.ecb) - .apply(cb, row, ob) - - cb.assign(n, n + 1L) - byteCount.foreach { bc => - cb.assign(bc, SCode.add(cb, bc, row.sizeToStoreInBytes(cb), true)) - } - } - cb.assign(filename, ctx.loadString(cb)) if (hasIndex) { val indexFile = cb.newLocal[String]("indexFile") @@ -289,11 +249,57 @@ case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, keyFields: Indexe cb.assign(os, Code.newInstance[ByteTrackingOutputStream, OutputStream](mb.create(filename))) cb.assign(ob, spec.buildCodeOutputBuffer(Code.checkcast[OutputStream](os))) cb.assign(n, 0L) + } + + def consumeElement(cb: EmitCodeBuilder, codeRow: SValue, elementRegion: Settable[Region]): Unit = { + val row = codeRow.asBaseStruct + + if (hasIndex) { + indexWriter.add(cb, { + val indexKeyPType = index.get._2 + IEmitCode.present(cb, indexKeyPType.asInstanceOf[PCanonicalBaseStruct] + .constructFromFields(cb, elementRegion, + indexKeyPType.fields.map{ f => + EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name)) + }, + deepCopy = true)) + }, + ob.invoke[Long]("indexOffset"), + IEmitCode.present(cb, PCanonicalStruct().loadCheapSCode(cb, 0L))) + } - stream.memoryManagedConsume(region, cb) { cb => - writeFile(cb, stream.element) + val key = SStackStruct.constructFromArgs(cb, elementRegion, keyType, keyType.fields.map { f => + EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name)) + }:_*) + + if (!keyFields.isEmpty) { + cb.ifx(distinctlyKeyed, { + lastSeenSettable.loadI(cb).consume(cb, { + // If there's no last seen, we are in the first row. 
+ cb.assign(firstSeenSettable, EmitValue.present(key.copyToRegion(cb, region, firstSeenSettable.st))) + }, { lastSeen => + val comparator = EQ(lastSeenSettable.emitType.virtualType).codeOrdering(cb.emb.ecb, lastSeenSettable.st, key.st) + val equalToLast = comparator(cb, lastSeenSettable, EmitValue.present(key)) + cb.ifx(equalToLast.asInstanceOf[Value[Boolean]], { + cb.assign(distinctlyKeyed, false) + }) + }) + }) + cb.assign(lastSeenSettable, IEmitCode.present(cb, key.copyToRegion(cb, region, lastSeenSettable.st))) } + cb += ob.writeByte(1.asInstanceOf[Byte]) + + spec.encodedType.buildEncoder(row.st, cb.emb.ecb) + .apply(cb, row, ob) + + cb.assign(n, n + 1L) + byteCount.foreach { bc => + cb.assign(bc, SCode.add(cb, bc, row.sizeToStoreInBytes(cb), true)) + } + } + + def result(): SValue = { cb += ob.writeByte(0.asInstanceOf[Byte]) if (hasIndex) indexWriter.close(cb) @@ -305,10 +311,29 @@ case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, keyFields: Indexe EmitCode.present(mb, new SInt64Value(n)), EmitCode.present(mb, new SBooleanValue(distinctlyKeyed)), firstSeenSettable, - lastSeenSettable) ++ byteCount.map(EmitCode.present(mb, _)) + lastSeenSettable + ) ++ byteCount.map(EmitCode.present(mb, _)) + SStackStruct.constructFromArgs(cb, region, returnType.asInstanceOf[TBaseStruct], values: _*) } } + + def consumeStream( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + stream: StreamProducer, + context: EmitCode, + region: Value[Region] + ): IEmitCode = { + val ctx = context.toI(cb).get(cb) + val consumer = new StreamConsumer(ctx, cb, region) + consumer.setup() + stream.memoryManagedConsume(region, cb) { cb => + val element = stream.element.toI(cb).get(cb, "row can't be missing") + consumer.consumeElement(cb, element, stream.elementRegion) + } + IEmitCode.present(cb, consumer.result()) + } } case class RVDSpecWriter(path: String, spec: RVDSpecMaker) extends MetadataWriter { @@ -571,6 +596,182 @@ case class TableTextFinalizer(outputPath: String, rowType: TStruct, delimiter: S } } +class FanoutWriterTarget( + val field: String, + val path: String, + val rowSpec: TypedCodecSpec, + val keyPType: PStruct, + val tableType: TableType, + val rowWriter: PartitionNativeWriter, + val globalWriter: PartitionNativeWriter +) + +case class TableNativeFanoutWriter( + val path: String, + val fields: IndexedSeq[String], + overwrite: Boolean = true, + stageLocally: Boolean = false, + codecSpecJSONStr: String = null +) extends TableWriter { + override def apply(ctx: ExecuteContext, mv: TableValue): Unit = + throw new UnsupportedOperationException("TableNativeFanoutWriter only supports lowered execution") + + override def lower( + ctx: ExecuteContext, + ts: TableStage, + t: TableIR, + r: RTable, + relationalLetsAbove: Map[String, IR] + ): IR = { + val partitioner = ts.partitioner + val bufferSpec = BufferSpec.parseOrDefault(codecSpecJSONStr) + val globalSpec = TypedCodecSpec(EType.fromTypeAndAnalysis(t.typ.globalType, r.globalType), t.typ.globalType, bufferSpec) + val targets = { + val rowType = t.typ.rowType + val rowRType = r.rowType + val keyType = partitioner.kType + val keyFields = keyType.fieldNames + + fields.map { field => + val targetPath = path + "/" + field + val fieldAndKey = (field +: keyFields) + val targetRowType = rowType.typeAfterSelectNames(fieldAndKey) + val targetRowRType = rowRType.select(fieldAndKey) + val rowSpec = TypedCodecSpec(EType.fromTypeAndAnalysis(targetRowType, targetRowRType), targetRowType, bufferSpec) + val keyPType = coerce[PStruct](rowSpec.decodedPType(keyType)) + val 
tableType = TableType(targetRowType, keyFields, t.typ.globalType) + val rowWriter = PartitionNativeWriter( + rowSpec, + keyFields, + s"$targetPath/rows/parts/", + Some(s"$targetPath/index/" -> keyPType), + if (stageLocally) Some(ctx.localTmpdir) else None + ) + val globalWriter = PartitionNativeWriter(globalSpec, IndexedSeq(), s"$targetPath/globals/parts/", None, None) + new FanoutWriterTarget(field, targetPath, rowSpec, keyPType, tableType, rowWriter, globalWriter) + }.toFastIndexedSeq + } + + val writeTables = ts.mapContexts { oldCtx => + val d = digitsNeeded(ts.numPartitions) + val partFiles = Literal(TArray(TString), Array.tabulate(ts.numPartitions)(i => s"${ partFile(d, i) }-").toFastIndexedSeq) + + zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => + MakeStruct(FastSeq( + "oldCtx" -> ctxElt, + "writeCtx" -> pf) + ) + } + }( + GetField(_, "oldCtx") + ).mapCollectWithContextsAndGlobals(relationalLetsAbove, "table_native_fanout_writer") { (rows, ctxRef) => + val file = GetField(ctxRef, "writeCtx") + WritePartition(rows, file + UUID4(), new PartitionNativeFanoutWriter(targets)) + } { (parts, globals) => + bindIR(parts) { fileCountAndDistinct => + Begin(targets.zipWithIndex.map { case (target, index) => + Begin(FastIndexedSeq( + WriteMetadata( + MakeArray( + GetField( + WritePartition( + MakeStream(FastSeq(globals), TStream(globals.typ)), + Str(partFile(1, 0)), + target.globalWriter + ), + "filePath" + ) + ), + RVDSpecWriter(s"${target.path}/globals", RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(1))) + ), + WriteMetadata( + ToArray(mapIR(ToStream(fileCountAndDistinct)) { fc => GetField(GetTupleElement(fc, index), "filePath") }), + RVDSpecWriter( + s"${target.path}/rows", + RVDSpecMaker( + target.rowSpec, + partitioner, + IndexSpec.emptyAnnotation("../index", coerce[PStruct](target.keyPType)) + ) + ) + ), + WriteMetadata( + ToArray(mapIR(ToStream(fileCountAndDistinct)) { fc => + SelectFields( + GetTupleElement(fc, index), + Seq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey") + ) + }), + TableSpecWriter(target.path, target.tableType, "rows", "globals", "references", log = true) + ) + )) + }.toFastIndexedSeq) + } + } + + targets.foldLeft(writeTables) { (rest: IR, target: FanoutWriterTarget) => + RelationalWriter.scoped( + target.path, overwrite, Some(target.tableType) + )( + rest + ) + } + } + + override def canLowerEfficiently: Boolean = true +} + +class PartitionNativeFanoutWriter( + targets: IndexedSeq[FanoutWriterTarget] +) extends PartitionWriter { + def consumeStream( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + stream: StreamProducer, + context: EmitCode, + region: Value[Region] + ): IEmitCode = { + val ctx = context.toI(cb).get(cb) + val consumers = targets.map { target => + new target.rowWriter.StreamConsumer(ctx, cb, region) + } + + consumers.foreach(_.setup()) + stream.memoryManagedConsume(region, cb) { cb => + val row = stream.element.toI(cb).get(cb, "row can't be missing") + + (consumers zip targets).foreach { case (consumer, target) => + consumer.consumeElement( + cb, + row.asBaseStruct.subset((target.keyPType.fieldNames :+ target.field):_*), + stream.elementRegion) + } + } + IEmitCode.present( + cb, + SStackStruct.constructFromArgs( + cb, + region, + returnType, + consumers.map(consumer => EmitCode.present(cb.emb, consumer.result())): _* + ) + ) + } + + def ctxType = TString + + def returnType: TTuple = + TTuple(targets.map(target => target.rowWriter.returnType):_*) + + def unionTypeRequiredness(returnType: 
TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = { + val targetReturnTypes = returnType.asInstanceOf[RTuple].fields.map(_.typ) + + ((targetReturnTypes) zip targets).foreach { case (returnType, target) => + target.rowWriter.unionTypeRequiredness(returnType, ctxType, streamType) + } + } +} + object WrappedMatrixNativeMultiWriter { implicit val formats: Formats = MatrixNativeMultiWriter.formats + ShortTypeHints(List(classOf[WrappedMatrixNativeMultiWriter])) + diff --git a/hail/src/main/scala/is/hail/expr/ir/UnaryOp.scala b/hail/src/main/scala/is/hail/expr/ir/UnaryOp.scala index 22fdb635936..abb7df4d0bc 100644 --- a/hail/src/main/scala/is/hail/expr/ir/UnaryOp.scala +++ b/hail/src/main/scala/is/hail/expr/ir/UnaryOp.scala @@ -15,6 +15,7 @@ object UnaryOp { case (Negate(), t@(TInt32 | TInt64 | TFloat32 | TFloat64)) => t case (Bang(), TBoolean) => TBoolean case (BitNot(), t@(TInt32 | TInt64)) => t + case (BitCount(), TInt32 | TInt64) => TInt32 } def returnTypeOption(op: UnaryOp, t: Type): Option[Type] = @@ -44,6 +45,7 @@ object UnaryOp { op match { case Negate() => -xx case BitNot() => ~xx + case BitCount() => xx.bitCount case _ => incompatible(t, op) } case TInt64 => @@ -51,6 +53,7 @@ object UnaryOp { op match { case Negate() => -xx case BitNot() => ~xx + case BitCount() => xx.bitCount case _ => incompatible(t, op) } case TFloat32 => @@ -72,6 +75,7 @@ object UnaryOp { case "-" | "Negate" => Negate() case "!" | "Bang" => Bang() case "~" | "BitNot" => BitNot() + case "BitCount" => BitCount() } } @@ -79,3 +83,4 @@ sealed trait UnaryOp { } case class Negate() extends UnaryOp { } case class Bang() extends UnaryOp { } case class BitNot() extends UnaryOp +case class BitCount() extends UnaryOp diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala index 1e1dad803cb..35df96e7170 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala @@ -127,10 +127,38 @@ case class AggElementsAggSig(nested: Seq[PhysicalAggSig]) extends case class ArrayLenAggSig(knownLength: Boolean, nested: Seq[PhysicalAggSig]) extends PhysicalAggSig(AggElementsLengthCheck(), ArrayAggStateSig(nested.map(_.state)), nested.flatMap(sig => sig.allOps).toArray) -case class Aggs(postAggIR: IR, init: IR, seqPerElt: IR, aggs: Array[PhysicalAggSig]) { +class Aggs(original: IR, rewriteMap: Memo[IR], bindingNodesReferenced: Memo[Unit], val init: IR, val seqPerElt: IR, val aggs: Array[PhysicalAggSig]) { val states: Array[AggStateSig] = aggs.map(_.state) val nAggs: Int = aggs.length + + lazy val postAggIR: IR = { + rewriteMap.lookup(original) + } + + def rewriteFromInitBindingRoot(f: IR => IR): IR = { + val irNumberMemo = Memo.empty[Int] + var i = 0 + // depth first search -- either DFS or BFS should work here given IR assumptions + VisitIR(original) { x => + irNumberMemo.bind(x, i) + i += 1 + } + + if (bindingNodesReferenced.m.isEmpty) { + f(rewriteMap.lookup(original)) + // find deepest binding node referenced + } else { + val rewriteRoot = bindingNodesReferenced.m.keys.maxBy(irNumberMemo.lookup) + // only support let nodes here -- other binders like stream operators are undefined behavior + RewriteTopDown.rewriteTopDown(original, { + case ir if RefEquality(ir) == rewriteRoot => + val Let(name, value, body) = ir + Let(name, value, f(rewriteMap.lookup(body))) + }).asInstanceOf[IR] + } + } + def isCommutative: Boolean = { def aggCommutes(agg: PhysicalAggSig): Boolean = 
agg.allOps.forall(AggIsCommutative(_)) aggs.forall(aggCommutes) @@ -369,27 +397,40 @@ object Extract { val let = new BoxedArrayBuilder[AggLet]() val ref = Ref(resultName, null) val memo = mutable.Map.empty[IR, Int] - val postAgg = extract(ir, ab, seq, let, memo, ref, r, isScan) + + val bindingNodesReferenced = Memo.empty[Unit] + val rewriteMap = Memo.empty[IR] + extract(ir, Env.empty, bindingNodesReferenced, rewriteMap, ab, seq, let, memo, ref, r, isScan) val (initOps, pAggSigs) = ab.result().unzip val rt = TTuple(initOps.map(_.aggSig.resultType): _*) ref._typ = rt - Aggs(postAgg, Begin(initOps), addLets(Begin(seq.result()), let.result()), pAggSigs) + new Aggs(ir, rewriteMap, bindingNodesReferenced, Begin(initOps), addLets(Begin(seq.result()), let.result()), pAggSigs) } - private def extract(ir: IR, ab: BoxedArrayBuilder[(InitOp, PhysicalAggSig)], seqBuilder: BoxedArrayBuilder[IR], letBuilder: BoxedArrayBuilder[AggLet], memo: mutable.Map[IR, Int], result: IR, r: RequirednessAnalysis, isScan: Boolean): IR = { - def extract(node: IR): IR = this.extract(node, ab, seqBuilder, letBuilder, memo, result, r, isScan) + private def extract(ir: IR, env: Env[RefEquality[IR]], bindingNodesReferenced: Memo[Unit], rewriteMap: Memo[IR], ab: BoxedArrayBuilder[(InitOp, PhysicalAggSig)], seqBuilder: BoxedArrayBuilder[IR], letBuilder: BoxedArrayBuilder[AggLet], memo: mutable.Map[IR, Int], result: IR, r: RequirednessAnalysis, isScan: Boolean): IR = { + // the env argument here tracks variable bindings that are accessible to init op arguments def newMemo: mutable.Map[IR, Int] = mutable.Map.empty[IR, Int] - ir match { + def bindInitArgRefs(initArgs: IndexedSeq[IR]): Unit = { + initArgs.foreach { arg => + val fv = FreeVariables(arg, false, false).eval + fv.m.keys + .flatMap { k => env.lookupOption(k) } + .foreach(bindingNodesReferenced.bind(_, ())) + } + } + + val newNode = ir match { case x@AggLet(name, value, body, _) => letBuilder += x - extract(body) + this.extract(body, env, bindingNodesReferenced, rewriteMap, ab, seqBuilder, letBuilder, memo, result, r, isScan) case x: ApplyAggOp if !isScan => val idx = memo.getOrElseUpdate(x, { val i = ab.length val op = x.aggSig.op + bindInitArgRefs(x.initOpArgs) val state = PhysicalAggSig(op, AggStateSig(op, x.initOpArgs, x.seqOpArgs, r)) ab += InitOp(i, x.initOpArgs, state) -> state seqBuilder += SeqOp(i, x.seqOpArgs, state) @@ -400,6 +441,7 @@ object Extract { val idx = memo.getOrElseUpdate(x, { val i = ab.length val op = x.aggSig.op + bindInitArgRefs(x.initOpArgs) val state = PhysicalAggSig(op, AggStateSig(op, x.initOpArgs, x.seqOpArgs, r)) ab += InitOp(i, x.initOpArgs, state) -> state seqBuilder += SeqOp(i, x.seqOpArgs, state) @@ -410,6 +452,7 @@ object Extract { val idx = memo.getOrElseUpdate(x, { val i = ab.length val initOpArgs = IndexedSeq(zero) + bindInitArgRefs(initOpArgs) val seqOpArgs = IndexedSeq(seqOp) val op = Fold() val resultEmitType = r(x).canonicalEmitType(x.typ) @@ -425,7 +468,7 @@ object Extract { case AggFilter(cond, aggIR, _) => val newSeq = new BoxedArrayBuilder[IR]() val newLet = new BoxedArrayBuilder[AggLet]() - val transformed = this.extract(aggIR, ab, newSeq, newLet, newMemo, result, r, isScan) + val transformed = this.extract(aggIR, env, bindingNodesReferenced, rewriteMap, ab, newSeq, newLet, newMemo, result, r, isScan) seqBuilder += If(cond, addLets(Begin(newSeq.result()), newLet.result()), Begin(FastIndexedSeq[IR]())) transformed @@ -433,7 +476,7 @@ object Extract { case AggExplode(array, name, aggBody, _) => val newSeq = new 
BoxedArrayBuilder[IR]() val newLet = new BoxedArrayBuilder[AggLet]() - val transformed = this.extract(aggBody, ab, newSeq, newLet, newMemo, result, r, isScan) + val transformed = this.extract(aggBody, env, bindingNodesReferenced, rewriteMap, ab, newSeq, newLet, newMemo, result, r, isScan) val (dependent, independent) = partitionDependentLets(newLet.result(), name) letBuilder ++= independent @@ -444,7 +487,7 @@ object Extract { val newAggs = new BoxedArrayBuilder[(InitOp, PhysicalAggSig)]() val newSeq = new BoxedArrayBuilder[IR]() val newRef = Ref(genUID(), null) - val transformed = this.extract(aggIR, newAggs, newSeq, letBuilder, newMemo, GetField(newRef, "value"), r, isScan) + val transformed = this.extract(aggIR, env, bindingNodesReferenced, rewriteMap, newAggs, newSeq, letBuilder, newMemo, GetField(newRef, "value"), r, isScan) val i = ab.length val (initOps, pAggSigs) = newAggs.result().unzip @@ -464,7 +507,7 @@ object Extract { val newSeq = new BoxedArrayBuilder[IR]() val newLet = new BoxedArrayBuilder[AggLet]() val newRef = Ref(genUID(), null) - val transformed = this.extract(aggBody, newAggs, newSeq, newLet, newMemo, newRef, r, isScan) + val transformed = this.extract(aggBody, env, bindingNodesReferenced, rewriteMap, newAggs, newSeq, newLet, newMemo, newRef, r, isScan) val (dependent, independent) = partitionDependentLets(newLet.result(), elementName) letBuilder ++= independent @@ -514,7 +557,20 @@ object Extract { case x: StreamAggScan => assert(!ContainsAgg(x)) x - case _ => MapIR(extract)(ir) + case x => + val newChildren = ir.children.zipWithIndex.map { case (child: IR, i) => + val nb = Bindings(x, i) + val newEnv = if (nb.nonEmpty) { + val re = RefEquality(x) + env.bindIterable(nb.map { case (name, _) => (name, re)}) + } else env + + this.extract(child, newEnv, bindingNodesReferenced, rewriteMap, ab, seqBuilder, letBuilder, memo, result, r, isScan) + } + Copy(x, newChildren) } + + rewriteMap.bind(ir, newNode) + newNode } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala index 5a85f3b07c6..b1326697ceb 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala @@ -814,7 +814,11 @@ object LowerTableIR { val newKeyType = newKey.typ.asInstanceOf[TStruct] val resultUID = genUID() - val aggs@Aggs(postAggIR, init, seq, aggSigs) = Extract(expr, resultUID, analyses.requirednessAnalysis) + val aggs = Extract(expr, resultUID, analyses.requirednessAnalysis) + val postAggIR = aggs.postAggIR + val init = aggs.init + val seq = aggs.seqPerElt + val aggSigs = aggs.aggs val partiallyAggregated = loweredChild.mapPartition(Some(FastIndexedSeq())) { partition => Let("global", loweredChild.globals, diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala index af7aa11599b..4cd82228bfb 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala @@ -94,37 +94,44 @@ case object LowerArrayAggsToRunAggsPass extends LoweringPass { val context: String = "LowerArrayAggsToRunAggs" def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = { - val r = Requiredness(ir, ctx) - RewriteBottomUp(ir, { + val x = ir.noSharing + val r = Requiredness(x, ctx) + RewriteBottomUp(x, { case x@StreamAgg(a, name, query) => val res = genUID() val aggs = Extract(query, res, r) - 
val newNode = Let( - res, - RunAgg( - Begin(FastSeq( - aggs.init, - StreamFor( - a, - name, - aggs.seqPerElt))), - aggs.results, - aggs.states), - aggs.postAggIR) + + val newNode = aggs.rewriteFromInitBindingRoot { root => + Let( + res, + RunAgg( + Begin(FastSeq( + aggs.init, + StreamFor( + a, + name, + aggs.seqPerElt))), + aggs.results, + aggs.states), + root) + } + if (newNode.typ != x.typ) - throw new RuntimeException(s"types differ:\n new: ${ newNode.typ }\n old: ${ x.typ }") + throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}") Some(newNode) case x@StreamAggScan(a, name, query) => val res = genUID() val aggs = Extract(query, res, r, isScan=true) - val newNode = RunAggScan( - a, - name, - aggs.init, - aggs.seqPerElt, - Let(res, aggs.results, aggs.postAggIR), - aggs.states - ) + val newNode = aggs.rewriteFromInitBindingRoot { root => + RunAggScan( + a, + name, + aggs.init, + aggs.seqPerElt, + Let(res, aggs.results, root), + aggs.states + ) + } if (newNode.typ != x.typ) throw new RuntimeException(s"types differ:\n new: ${ newNode.typ }\n old: ${ x.typ }") Some(newNode) diff --git a/hail/src/main/scala/is/hail/io/index/IndexReadIterator.scala b/hail/src/main/scala/is/hail/io/index/IndexReadIterator.scala deleted file mode 100644 index 3d6f28e4fb7..00000000000 --- a/hail/src/main/scala/is/hail/io/index/IndexReadIterator.scala +++ /dev/null @@ -1,101 +0,0 @@ -package is.hail.io.index - -import java.io.InputStream - -import is.hail.asm4s.HailClassLoader -import is.hail.annotations.Region -import is.hail.io.Decoder -import is.hail.types.virtual.TStruct -import is.hail.utils.{ByteTrackingInputStream, Interval} -import org.apache.spark.ExposedMetrics -import org.apache.spark.executor.InputMetrics -import org.apache.spark.sql.Row - -class IndexReadIterator( - theHailClassLoader: HailClassLoader, - makeDec: (InputStream, HailClassLoader) => Decoder, - region: Region, - in: InputStream, - idxr: IndexReader, - offsetField: String, // can be null - bounds: Interval, - metrics: InputMetrics = null -) extends Iterator[Long] { - - private[this] val (startIdx, endIdx) = idxr.boundsByInterval(bounds) - private[this] var n = endIdx - startIdx - - private[this] val trackedIn = new ByteTrackingInputStream(in) - private[this] val field = Option(offsetField).map { f => - idxr.annotationType.asInstanceOf[TStruct].fieldIdx(f) - } - private[this] val dec = - try { - if (n > 0) { - val dec = makeDec(trackedIn, theHailClassLoader) - val i = idxr.queryByIndex(startIdx) - val off = field.map { j => - i.annotation.asInstanceOf[Row].getAs[Long](j) - }.getOrElse(i.recordOffset) - dec.seek(off) - dec - } else { - in.close() - null - } - } catch { - case e: Exception => - idxr.close() - in.close() - throw e - } - - private[this] var closed = false - - private var cont: Byte = if (dec != null) dec.readByte() else 0 - if (cont == 0) { - idxr.close() - if (dec != null) dec.close() - } - - def hasNext: Boolean = cont != 0 && n > 0 - - def next(): Long = _next() - - def _next(): Long = { - if (!hasNext) - throw new NoSuchElementException("next on empty iterator") - - n -= 1 - try { - val res = dec.readRegionValue(region) - cont = dec.readByte() - if (metrics != null) { - ExposedMetrics.incrementRecord(metrics) - ExposedMetrics.incrementBytes(metrics, trackedIn.bytesReadAndClear()) - } - - if (cont == 0) { - close() - } - - res - } catch { - case e: Exception => - close() - throw e - } - } - - def close(): Unit = { - if (!closed) { - idxr.close() - if (dec != null) dec.close() - closed = true - } 
- } - - override def finalize(): Unit = { - close() - } -} diff --git a/hail/src/main/scala/is/hail/io/index/IndexWriter.scala b/hail/src/main/scala/is/hail/io/index/IndexWriter.scala index 38ab80166bf..0e87a3bb3f6 100644 --- a/hail/src/main/scala/is/hail/io/index/IndexWriter.scala +++ b/hail/src/main/scala/is/hail/io/index/IndexWriter.scala @@ -51,6 +51,10 @@ case class IndexMetadataUntypedJSON( fileVersion, branchingFactor, height, keyType, annotationType, nKeys, indexPath, rootOffset, attributes) + + def toFileMetadata: VariableMetadata = VariableMetadata( + branchingFactor, height, nKeys, rootOffset, attributes + ) } case class IndexMetadata( diff --git a/hail/src/main/scala/is/hail/io/index/MaybeIndexedReadZippedIterator.scala b/hail/src/main/scala/is/hail/io/index/MaybeIndexedReadZippedIterator.scala deleted file mode 100644 index 5018189714d..00000000000 --- a/hail/src/main/scala/is/hail/io/index/MaybeIndexedReadZippedIterator.scala +++ /dev/null @@ -1,143 +0,0 @@ -package is.hail.io.index - -import java.io.InputStream - -import is.hail.annotations.Region -import is.hail.asm4s.AsmFunction3RegionLongLongLong -import is.hail.io.Decoder -import is.hail.types.virtual.TStruct -import is.hail.utils.{ByteTrackingInputStream, Interval} -import org.apache.spark.executor.InputMetrics -import org.apache.spark.sql.Row - -class MaybeIndexedReadZippedIterator( - mkRowsDec: (InputStream) => Decoder, - mkEntriesDec: (InputStream) => Decoder, - inserter: AsmFunction3RegionLongLongLong, - region: Region, - isRows: InputStream, - isEntries: InputStream, - idxr: IndexReader, - rowsOffsetField: String, - entriesOffsetField: String, - bounds: Interval, - metrics: InputMetrics = null -) extends Iterator[Long] { - - private[this] var closed: Boolean = false - - private[this] val startAndEnd = Option(idxr).map(_.boundsByInterval(bounds)) - private[this] val firstAnnotation = try { - startAndEnd.flatMap { case (start, end) => - if (end == start || idxr.nKeys == 0) None else Some(idxr.queryByIndex(start)) - } - } catch { - case e: Exception => - if (idxr != null) - idxr.close() - isRows.close() - isEntries.close() - throw e - } - - private[this] var n = startAndEnd.map(x => x._2 - x._1) - - private[this] val trackedRowsIn = new ByteTrackingInputStream(isRows) - private[this] val trackedEntriesIn = new ByteTrackingInputStream(isEntries) - - private[this] val rowsIdxField = Option(rowsOffsetField).map { f => idxr.annotationType.asInstanceOf[TStruct].fieldIdx(f) } - private[this] val entriesIdxField = Option(entriesOffsetField).map { f => idxr.annotationType.asInstanceOf[TStruct].fieldIdx(f) } - - private[this] val rows = try { - if (n.forall(_ > 0)) { - val dec = mkRowsDec(trackedRowsIn) - firstAnnotation.foreach { i => - val off = rowsIdxField.map { j => i.annotation.asInstanceOf[Row].getAs[Long](j) }.getOrElse(i.recordOffset) - dec.seek(off) - } - dec - } else { - isRows.close() - isEntries.close() - null - } - } catch { - case e: Exception => - if (idxr != null) - idxr.close() - isRows.close() - isEntries.close() - throw e - } - private[this] val entries = try { - if (rows == null) { - null - } else { - val dec = mkEntriesDec(trackedEntriesIn) - firstAnnotation.foreach { i => - val off = entriesIdxField.map { j => i.annotation.asInstanceOf[Row].getAs[Long](j) }.getOrElse(i.recordOffset) - dec.seek(off) - } - dec - } - } catch { - case e: Exception => - if (idxr != null) - idxr.close() - isRows.close() - isEntries.close() - throw e - } - - require(!((rows == null) ^ (entries == null))) - - private def 
nextCont(): Byte = { - val br = rows.readByte() - val be = entries.readByte() - assert(br == be) - br - } - - private var cont: Byte = if (rows != null) nextCont() else 0 - - def hasNext: Boolean = cont != 0 && n.forall(_ > 0) - - def next(): Long = _next() - - def _next(): Long = { - if (!hasNext) - throw new NoSuchElementException("next on empty iterator") - - n = n.map(_ - 1) - try { - val rowOff = rows.readRegionValue(region) - val entOff = entries.readRegionValue(region) - val off = inserter(region, rowOff, entOff) - cont = nextCont() - - if (cont == 0) { - close() - } - - off - } catch { - case e: Exception => - close() - throw e - } - } - - def close(): Unit = { - if (!closed) { - if (idxr != null) - idxr.close() - if (rows != null) rows.close() - if (entries != null) entries.close() - closed = true - } - } - - override def finalize(): Unit = { - close() - } -} diff --git a/hail/src/main/scala/is/hail/io/index/StagedIndexReader.scala b/hail/src/main/scala/is/hail/io/index/StagedIndexReader.scala new file mode 100644 index 00000000000..990ca5e4254 --- /dev/null +++ b/hail/src/main/scala/is/hail/io/index/StagedIndexReader.scala @@ -0,0 +1,275 @@ +package is.hail.io.index + +import is.hail.annotations._ +import is.hail.asm4s._ +import is.hail.expr.ir.functions.IntervalFunctions.compareStructWithPartitionIntervalEndpoint +import is.hail.expr.ir.{BinarySearch, EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitSettable} +import is.hail.io.fs.FS +import is.hail.rvd.AbstractIndexSpec +import is.hail.types.physical.stypes.concrete.SStackStruct +import is.hail.types.physical.stypes.interfaces.{SBaseStructValue, SIntervalValue, primitive} +import is.hail.types.physical.stypes.primitives.{SBooleanValue, SInt64} +import is.hail.types.physical.stypes.{EmitType, SSettable} +import is.hail.types.physical.{PCanonicalArray, PCanonicalBaseStruct} +import is.hail.types.virtual.{TBoolean, TInt64, TTuple} +import is.hail.utils._ + +import java.io.InputStream + +case class VariableMetadata( + branchingFactor: Int, + height: Int, + nKeys: Long, + rootOffset: Long, + attributes: Map[String, Any] +) + + +class StagedIndexReader(emb: EmitMethodBuilder[_], spec: AbstractIndexSpec) { + private[this] val cache: Settable[LongToRegionValueCache] = emb.genFieldThisRef[LongToRegionValueCache]("index_cache") + private[this] val metadata: Settable[VariableMetadata] = emb.genFieldThisRef[VariableMetadata]("index_file_metadata") + + private[this] val is: Settable[ByteTrackingInputStream] = emb.genFieldThisRef[ByteTrackingInputStream]("index_is") + + private[this] val leafPType = spec.leafCodec.encodedType.decodedPType(spec.leafCodec.encodedVirtualType) + private[this] val internalPType = spec.internalNodeCodec.encodedType.decodedPType(spec.internalNodeCodec.encodedVirtualType) + private[this] val leafDec = spec.leafCodec.encodedType.buildDecoder(spec.leafCodec.encodedVirtualType, emb.ecb) + private[this]val internalDec = spec.internalNodeCodec.encodedType.buildDecoder(spec.internalNodeCodec.encodedVirtualType, emb.ecb) + + private[this] val leafChildType = leafPType.asInstanceOf[PCanonicalBaseStruct].types(1).asInstanceOf[PCanonicalArray].elementType.sType + val leafChildEmitType = EmitType(leafChildType, false) + private[this] val queryType = SStackStruct(TTuple(TInt64, leafChildType.virtualType), FastIndexedSeq(EmitType(SInt64, true), leafChildEmitType)) + + def initialize(cb: EmitCodeBuilder, + indexPath: Value[String] + ): Unit = { + val fs = cb.emb.getFS + cb.assign(cache, Code.newInstance[LongToRegionValueCache, 
Int](16)) + cb.assign(metadata, Code.invokeScalaObject2[FS, String, IndexMetadataUntypedJSON]( + IndexReader.getClass, "readUntyped", fs, indexPath + ).invoke[VariableMetadata]("toFileMetadata")) + + // FIXME: hardcoded. Will break if we change spec -- assumption not introduced with this code, but propagated. + cb.assign(is, Code.newInstance[ByteTrackingInputStream, InputStream](cb.emb.open(indexPath.concat("/index"), false))) + + } + + def close(cb: EmitCodeBuilder): Unit = { + cb += is.invoke[Unit]("close") + cb += cache.invoke[Unit]("free") + cb.assign(is, Code._null) + cb.assign(cache, Code._null) + cb.assign(metadata, Code._null) + } + + /** + * returns tuple of (count, starting leaf) + * memory of starting leaf is not owned by `region`, consumers should deep copy if necessary + * starting leaf may be missing if the index is empty + */ + def queryInterval(cb: EmitCodeBuilder, + region: Value[Region], + interval: SIntervalValue): SBaseStructValue = { + + val n = cb.newLocal[Long]("n") + val startLeaf = cb.emb.newEmitLocal(leafChildEmitType) + + val start = interval.loadStart(cb).get(cb).asBaseStruct + val end = interval.loadEnd(cb).get(cb).asBaseStruct + val includesStart = interval.includesStart() + val includesEnd = interval.includesEnd() + + val startQuerySettable = cb.newSLocal(queryType, "startQuery") + cb.ifx(includesStart, + cb.assign(startQuerySettable, queryBound(cb, region, start, primitive(false), "lower")), + cb.assign(startQuerySettable, queryBound(cb, region, start, primitive(true), "upper")) + ) + + val endQuerySettable = cb.newSLocal(queryType, "endQuery") + cb.ifx(includesEnd, + cb.assign(endQuerySettable, queryBound(cb, region, end, primitive(true), "upper")), + cb.assign(endQuerySettable, queryBound(cb, region, end, primitive(false), "lower")) + ) + + cb.assign(n, + endQuerySettable.asBaseStruct.loadField(cb, 0).get(cb).asInt64.value - + startQuerySettable.asBaseStruct.loadField(cb, 0).get(cb).asInt64.value + ) + cb.assign(startLeaf, startQuerySettable.asBaseStruct.loadField(cb, 1)) + + cb.ifx(n < 0L, cb._fatal("n less than 0: ", n.toS, ", startQuery=", cb.strValue(startQuerySettable), ", endQuery=", cb.strValue(endQuerySettable))) + + + SStackStruct.constructFromArgs(cb, region, TTuple(TInt64, startLeaf.st.virtualType), EmitCode.present(cb.emb, primitive(n)), startLeaf) + } + + // internal node is an array of children + private[io] def readInternalNode(cb: EmitCodeBuilder, offset: Value[Long]): SBaseStructValue = { + val ret = cb.newSLocal(internalPType.sType, "internalNode") + + // returns an address if cached, or -1L if not found + val cached = cb.memoize(cache.invoke[Long, Long]("get", offset)) + + cb.ifx(cached cne -1L, { + cb.assign(ret, internalPType.loadCheapSCode(cb, cached)) + }, { + cb.assign(ret, cb.invokeSCode(cb.emb.ecb.getOrGenEmitMethod("readInternalNode", ("readInternalNode", this), FastIndexedSeq(LongInfo), ret.st.paramType) { emb => + emb.emitSCode { cb => + val offset = emb.getCodeParam[Long](1) + cb += is.invoke[Long, Unit]("seek", offset) + val ib = cb.memoize(spec.internalNodeCodec.buildCodeInputBuffer(is)) + cb.ifx(ib.readByte() cne 1, cb._fatal("bad buffer at internal!")) + val region = cb.memoize(cb.emb.ecb.pool().invoke[Region.Size, Region]("getRegion", Region.TINIER)) + val internalNode = internalDec.apply(cb, region, ib) + val internalNodeAddr = internalPType.store(cb, region, internalNode, false) + cb += cache.invoke[Long, Region, Long, Unit]("put", offset, region, internalNodeAddr) + internalNode + } + }, offset)) + }) + + 
ret.asBaseStruct + } + + // leaf node is a struct + private[io] def readLeafNode(cb: EmitCodeBuilder, offset: Value[Long]): SBaseStructValue = { + val ret = cb.newSLocal(leafPType.sType, "leafNode") + + // returns an address if cached, or -1L if not found + val cached = cb.memoize(cache.invoke[Long, Long]("get", offset)) + + cb.ifx(cached cne -1L, { + cb.assign(ret, leafPType.loadCheapSCode(cb, cached)) + }, { + cb.assign(ret, cb.invokeSCode(cb.emb.ecb.getOrGenEmitMethod("readLeafNode", ("readLeafNode", this), FastIndexedSeq(LongInfo), ret.st.paramType) { emb => + emb.emitSCode { cb => + val offset = emb.getCodeParam[Long](1) + cb += is.invoke[Long, Unit]("seek", offset) + val ib = cb.memoize(spec.leafCodec.buildCodeInputBuffer(is)) + cb.ifx(ib.readByte() cne 0, cb._fatal("bad buffer at leaf!")) + val region = cb.memoize(cb.emb.ecb.pool().invoke[Region.Size, Region]("getRegion", Region.TINIER)) + val leafNode = leafDec.apply(cb, region, ib) + val leafNodeAddr = leafPType.store(cb, region, leafNode, false) + cb += cache.invoke[Long, Region, Long, Unit]("put", offset, region, leafNodeAddr) + leafPType.loadCheapSCode(cb, leafNodeAddr) + } + }, offset)) + }) + ret.asBaseStruct + } + + // returns queryType + def queryBound(cb: EmitCodeBuilder, region: Value[Region], partitionBoundLeftEndpoint: SBaseStructValue, leansRight: SBooleanValue, boundType: String): SBaseStructValue = { + cb.invokeSCode( + cb.emb.ecb.getOrGenEmitMethod("lowerBound", + ("lowerBound", this, boundType), + FastIndexedSeq(typeInfo[Region], partitionBoundLeftEndpoint.st.paramType, leansRight.st.paramType), + queryType.paramType) { emb => + emb.emitSCode { cb => + val region = emb.getCodeParam[Region](1) + val endpoint = emb.getSCodeParam(2).asBaseStruct + val leansRight = emb.getSCodeParam(3).asBoolean + queryType.coerceOrCopy(cb, region, queryBound(cb, region, endpoint, leansRight, cb.memoize(metadata.invoke[Int]("height") - 1), cb.memoize(metadata.invoke[Long]("rootOffset")), boundType), false) } + }, region, partitionBoundLeftEndpoint, leansRight).asBaseStruct + } + + // partitionBoundEndpoint is a tuple(partitionBoundEndpoint, bool) + // returns a tuple of (index, LeafChild) + private def queryBound(cb: EmitCodeBuilder, + region: Value[Region], + endpoint: SBaseStructValue, + leansRight: SBooleanValue, + level: Value[Int], + offset: Value[Long], + boundType: String): SBaseStructValue = { + + val rInd: Settable[Long] = cb.newLocal[Long]("lowerBoundIndex") + val rLeafChild: EmitSettable = cb.emb.newEmitLocal(leafChildEmitType) + + val levelSettable = cb.newLocal[Int]("lowerBound_level") + val offsetSettable = cb.newLocal[Long]("lowerBound_offset") + + cb.assign(levelSettable,level) + cb.assign(offsetSettable,offset) + + val boundAndSignTuple = SStackStruct.constructFromArgs(cb, + region, + TTuple(endpoint.st.virtualType, TBoolean), + EmitCode.present(cb.emb, endpoint), + EmitCode.present(cb.emb, leansRight) + ) + + val Lstart = CodeLabel() + cb.define(Lstart) + + cb.ifx(levelSettable ceq 0, { + val node = readLeafNode(cb, offsetSettable).asBaseStruct + + /* + LeafNode( + firstIndex: Long, + children: IndexedSeq[LeafChild] + LeafChild( + key: Annotation, + recordOffset: Long, + annotation: Annotation) + */ + val children = node.asBaseStruct.loadField(cb, "keys").get(cb).asIndexable + + val idx = new BinarySearch(cb.emb, + children.st, + EmitType(boundAndSignTuple.st, true), + ((cb, elt) => cb.memoize(elt.get(cb).asBaseStruct.loadField(cb, "key"))), + bound=boundType, + ltF = { (cb, containerEltEV, partBoundEV) => + val 
containerElt = containerEltEV.get(cb).asBaseStruct + val partBound = partBoundEV.get(cb).asBaseStruct + val endpoint = partBound.loadField(cb, 0).get(cb).asBaseStruct + val leansRight = partBound.loadField(cb, 1).get(cb).asBoolean.value + val comp = compareStructWithPartitionIntervalEndpoint(cb, containerElt, endpoint, leansRight) + val ltOrGt = cb.memoize(if (boundType == "lower") comp < 0 else comp > 0) + ltOrGt + } + ) + .search(cb, children, EmitCode.present(cb.emb, boundAndSignTuple)) + + val firstIndex = node.asBaseStruct.loadField(cb, "first_idx").get(cb).asInt64.value.get + val updatedIndex = firstIndex + idx.toL + cb.assign(rInd, updatedIndex) + val idxWithModification = cb.memoize((if (boundType == "lower") idx else cb.memoize(idx-1)) min children.loadLength()) + val leafChild = children.loadElement(cb, idxWithModification).get(cb).asBaseStruct + cb.assign(rLeafChild, EmitCode.present(cb.emb, leafChild)) + }, { + val children = readInternalNode(cb, offsetSettable).loadField(cb, "children").get(cb).asIndexable + cb.ifx(children.loadLength() ceq 0, { + // empty interal node occurs if the indexed file contains no keys + cb.assign(rInd, 0L) + cb.assign(rLeafChild, EmitCode.missing(cb.emb, leafChildType)) + }, { + val idx = new BinarySearch(cb.emb, + children.st, + EmitType(boundAndSignTuple.st, true), + ((cb, elt) => cb.memoize(elt.get(cb).asBaseStruct.loadField(cb, "first_key"))), + bound=boundType, + ltF = { (cb, containerEltEV, partBoundEV) => + val containerElt = containerEltEV.get(cb).asBaseStruct + val partBound = partBoundEV.get(cb).asBaseStruct + val endpoint = partBound.loadField(cb, 0).get(cb).asBaseStruct + val leansRight = partBound.loadField(cb, 1).get(cb).asBoolean.value + val comp = compareStructWithPartitionIntervalEndpoint(cb, containerElt, endpoint, leansRight) + val ltOrGt = cb.memoize(if (boundType == "lower") comp < 0 else comp > 0) + ltOrGt + } + ) + .search(cb, children, EmitCode.present(cb.emb, boundAndSignTuple)) + cb.assign(levelSettable, levelSettable-1) + cb.assign(offsetSettable, children.loadElement(cb, (idx-1).max(0)).get(cb).asBaseStruct.loadField(cb, "index_file_offset").get(cb).asLong.value) + cb.goto(Lstart) + }) + }) + + SStackStruct.constructFromArgs(cb, region, queryType.virtualType, + EmitCode.present(cb.emb, primitive(rInd)), + rLeafChild) + } +} diff --git a/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala b/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala index f62e9665ce7..ec6eda2b684 100644 --- a/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala +++ b/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala @@ -52,25 +52,6 @@ object AbstractRVDSpec { } } - def readLocal( - ctx: ExecuteContext, - path: String, - enc: AbstractTypedCodecSpec, - partFiles: Array[String], - requestedType: TStruct): (PStruct, Long) = { - assert(partFiles.length == 1) - val fs = ctx.fs - val r = ctx.r - - val (rType: PStruct, dec) = enc.buildDecoder(ctx, requestedType) - - val f = partPath(path, partFiles(0)) - using(fs.open(f)) { in => - val Array(rv) = HailContext.readRowsPartition(dec)(ctx.theHailClassLoader, r, in).toArray - (rType, rv) - } - } - def partPath(path: String, partFile: String): String = path + "/parts/" + partFile def writeSingle( @@ -201,61 +182,6 @@ object AbstractRVDSpec { } } } - - def readZipped( - ctx: ExecuteContext, - specLeft: AbstractRVDSpec, - specRight: AbstractRVDSpec, - pathLeft: String, - pathRight: String, - newPartitioner: Option[RVDPartitioner], - filterIntervals: Boolean, - requestedType: Type, - leftRType: TStruct, 
rightRType: TStruct, - requestedKey: IndexedSeq[String], - fieldInserter: (ExecuteContext, PStruct, PStruct) => (PStruct, (HailClassLoader, FS, Int, Region) => AsmFunction3RegionLongLongLong) - ): RVD = { - require(specRight.key.isEmpty) - val partitioner = specLeft.partitioner - - val extendedNewPartitioner = newPartitioner.map(_.extendKey(partitioner.kType)) - val (parts, tmpPartitioner) = extendedNewPartitioner match { - case Some(np) => - val tmpPart = np.intersect(partitioner) - assert(specLeft.key.nonEmpty) - val p = tmpPart.rangeBounds.map { b => specLeft.partFiles(partitioner.lowerBoundInterval(b)) } - (p, tmpPart) - case None => - // need to remove partitions with degenerate intervals - // these partitions are necessarily empty - val iOrd = partitioner.kord.intervalEndpointOrdering - val includedIndices = (0 until partitioner.numPartitions).filter { i => - val rb = partitioner.rangeBounds(i) - !rb.isDisjointFrom(iOrd, rb) - }.toArray - (includedIndices.map(specLeft.partFiles), partitioner.copy(rangeBounds = includedIndices.map(partitioner.rangeBounds))) - } - - val (isl, isr) = (specLeft, specRight) match { - case (l: Indexed, r: Indexed) => (Some(l.indexSpec), Some(r.indexSpec)) - case _ => (None, None) - } - - val (leftPType: PStruct, makeLeftDec) = specLeft.typedCodecSpec.buildDecoder(ctx, leftRType) - val (rightPType: PStruct, makeRightDec) = specRight.typedCodecSpec.buildDecoder(ctx, rightRType) - - val (t: PStruct, makeInserter) = fieldInserter(ctx, leftPType, rightPType) - assert(t.virtualType == requestedType) - val crdd = HailContext.readRowsSplit(ctx, - pathLeft, pathRight, isl, isr, - parts, tmpPartitioner.rangeBounds, - makeLeftDec, makeRightDec, makeInserter) - val tmprvd = RVD(RVDType(t, requestedKey), tmpPartitioner.coarsen(requestedKey.length), crdd) - extendedNewPartitioner match { - case Some(part) if !filterIntervals => tmprvd.repartition(ctx, part.coarsen(requestedKey.length)) - case _ => tmprvd - } - } } trait Indexed extends AbstractRVDSpec { @@ -279,22 +205,6 @@ abstract class AbstractRVDSpec { def attrs: Map[String, String] - def read( - ctx: ExecuteContext, - path: String, - requestedType: TStruct, - newPartitioner: Option[RVDPartitioner] = None, - filterIntervals: Boolean = false - ): RVD = newPartitioner match { - case Some(_) => fatal("attempted to read unindexed data as indexed") - case None => - val requestedKey = key.takeWhile(requestedType.hasField) - val (pType: PStruct, crdd) = HailContext.readRows(ctx, path, typedCodecSpec, partFiles, requestedType) - val rvdType = RVDType(pType, requestedKey) - - RVD(rvdType, partitioner.coarsen(requestedKey.length), crdd) - } - def readTableStage( ctx: ExecuteContext, path: String, @@ -324,9 +234,6 @@ abstract class AbstractRVDSpec { body) } - def readLocalSingleRow(ctx: ExecuteContext, path: String, requestedType: TStruct): (PStruct, Long) = - AbstractRVDSpec.readLocal(ctx, path, typedCodecSpec, partFiles, requestedType) - def write(fs: FS, path: String) { using(fs.create(path + "/metadata.json.gz")) { out => import AbstractRVDSpec.formats @@ -497,36 +404,6 @@ case class IndexedRVDSpec2(_key: IndexedSeq[String], val attrs: Map[String, String] = _attrs - override def read( - ctx: ExecuteContext, - path: String, - requestedType: TStruct, - newPartitioner: Option[RVDPartitioner] = None, - filterIntervals: Boolean = false - ): RVD = { - newPartitioner match { - case Some(np) => - val extendedNP = np.extendKey(partitioner.kType) - val requestedKey = key.takeWhile(requestedType.hasField) - val tmpPartitioner = 
partitioner.intersect(extendedNP) - - assert(key.nonEmpty) - val parts = tmpPartitioner.rangeBounds.map { b => partFiles(partitioner.lowerBoundInterval(b)) } - - val (decPType: PStruct, crdd) = HailContext.readIndexedRows(ctx, path, _indexSpec, typedCodecSpec, parts, tmpPartitioner.rangeBounds, requestedType) - val rvdType = RVDType(decPType, requestedKey) - val tmprvd = RVD(rvdType, tmpPartitioner.coarsen(requestedKey.length), crdd) - - if (filterIntervals) - tmprvd - else - tmprvd.repartition(ctx, extendedNP.coarsen(requestedKey.length)) - case None => - // indexed reads are costly; don't use an indexed read when possible - super.read(ctx, path, requestedType, None, filterIntervals) - } - } - override def readTableStage( ctx: ExecuteContext, path: String, diff --git a/hail/src/main/scala/is/hail/services/package.scala b/hail/src/main/scala/is/hail/services/package.scala index 07d7683ca42..6b0cb342f1e 100644 --- a/hail/src/main/scala/is/hail/services/package.scala +++ b/hail/src/main/scala/is/hail/services/package.scala @@ -41,11 +41,11 @@ package object services { // true error. val e = reactor.core.Exceptions.unwrap(_e) e match { + case e: HttpResponseException => + e.getStatusCode() == 400 && e.getMessage.contains("Invalid grant: account not found") case e @ (_: SSLException | _: StorageException | _: IOException) => val cause = e.getCause cause != null && isRetryOnceError(cause) - case e: HttpResponseException => - e.getStatusCode() == 400 && e.getMessage.contains("Invalid grant: account not found") case _ => false } diff --git a/hail/src/main/scala/is/hail/utils/Cache.scala b/hail/src/main/scala/is/hail/utils/Cache.scala index daaf55887fa..007523bc7fa 100644 --- a/hail/src/main/scala/is/hail/utils/Cache.scala +++ b/hail/src/main/scala/is/hail/utils/Cache.scala @@ -1,5 +1,7 @@ package is.hail.utils +import is.hail.annotations.{Region, RegionMemory} + import java.util import java.util.Map.Entry @@ -14,3 +16,36 @@ class Cache[K, V](capacity: Int) { def size: Int = synchronized { m.size() } } + +class LongToRegionValueCache(capacity: Int) { + private[this] val m = new util.LinkedHashMap[Long, (RegionMemory, Long)](capacity, 0.75f, true) { + override def removeEldestEntry(eldest: Entry[Long, (RegionMemory, Long)]): Boolean = { + val b = (size() > capacity) + if (b) { + val (rm, _) = eldest.getValue + rm.release() + } + b + } + } + + // the cache takes ownership of the region passed in + def put(key: Long, region: Region, addr: Long): Unit = { + val rm = region.getMemory() + m.put(key, (rm, addr)) + } + + // returns -1 if not in cache + def get(key: Long): Long = { + val v = m.get(key) + if (v == null) + -1L + else + v._2 + } + + def free(): Unit = { + m.forEach((k, v) => v._1.release()) + m.clear() + } +} \ No newline at end of file diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index 8fb9d857a88..d88866cc478 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -209,6 +209,17 @@ class IRSuite extends HailSuite { ) } + @Test def testApplyUnaryPrimOpBitCount() { + assertAllEvalTo( + (ApplyUnaryPrimOp(BitCount(), I32(0xdeadbeef)), Integer.bitCount(0xdeadbeef)), + (ApplyUnaryPrimOp(BitCount(), I32(-0xdeadbeef)), Integer.bitCount(-0xdeadbeef)), + (ApplyUnaryPrimOp(BitCount(), i32na), null), + (ApplyUnaryPrimOp(BitCount(), I64(0xdeadbeef12345678L)), java.lang.Long.bitCount(0xdeadbeef12345678L)), + (ApplyUnaryPrimOp(BitCount(), I64(-0xdeadbeef12345678L)), 
java.lang.Long.bitCount(-0xdeadbeef12345678L)), + (ApplyUnaryPrimOp(BitCount(), i64na), null) + ) + } + @Test def testApplyBinaryPrimOpAdd() { def assertSumsTo(t: Type, x: Any, y: Any, sum: Any) { assertEvalsTo(ApplyBinaryPrimOp(Add(), In(0, t), In(1, t)), FastIndexedSeq(x -> t, y -> t), sum) diff --git a/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala b/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala index fd595c739b4..59aa85a60b9 100644 --- a/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala @@ -10,6 +10,7 @@ import is.hail.expr.ir.orderings.CodeOrdering import is.hail.rvd.RVDType import is.hail.types.physical._ import is.hail.types.physical.stypes.EmitType +import is.hail.types.physical.stypes.interfaces.SBaseStructValue import is.hail.types.virtual._ import is.hail.utils._ import org.apache.spark.sql.Row @@ -458,9 +459,11 @@ class OrderingSuite extends HailSuite { val cset = fb.getCodeParam[Long](2) val cetuple = fb.getCodeParam[Long](3) - val bs = new BinarySearch(fb.apply_method, pset.sType, EmitType(pset.elementType.sType, true), keyOnly = false) + val bs = new BinarySearch(fb.apply_method, pset.sType, EmitType(pset.elementType.sType, true), { + (cb, elt) => elt + }) fb.emitWithBuilder(cb => - bs.lowerBound(cb, pset.loadCheapSCode(cb, cset), + bs.search(cb, pset.loadCheapSCode(cb, cset), EmitCode.fromI(fb.apply_method)(cb => IEmitCode.present(cb, pt.loadCheapSCode(cb, pTuple.loadField(cetuple, 0)))))) val asArray = SafeIndexedSeq(pArray, soff) @@ -493,10 +496,15 @@ class OrderingSuite extends HailSuite { val cdict = fb.getCodeParam[Long](2) val cktuple = fb.getCodeParam[Long](3) - val bs = new BinarySearch(fb.apply_method, pDict.sType, EmitType(pDict.keyType.sType, false), keyOnly = true) + val bs = new BinarySearch(fb.apply_method, pDict.sType, EmitType(pDict.keyType.sType, false), { (cb, elt) => + cb.memoize(elt.toI(cb).flatMap(cb) { + case x: SBaseStructValue => + x.loadField(cb, 0) + }) + }) fb.emitWithBuilder(cb => - bs.lowerBound(cb, pDict.loadCheapSCode(cb, cdict), + bs.search(cb, pDict.loadCheapSCode(cb, cdict), EmitCode.fromI(fb.apply_method)(cb => IEmitCode.present(cb, pDict.keyType.loadCheapSCode(cb, ptuple.loadField(cktuple, 0)))))) val asArray = SafeIndexedSeq(PCanonicalArray(pDict.elementType), soff)
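// --- Supplementary sketch (editor's addition, not part of the patch) ---
// The new BitCount() unary op in UnaryOp.scala lowers to a population count on the
// 32- or 64-bit two's-complement value and always yields an Int32, matching the
// IRSuite cases above. Plain JVM calls only; no Hail IR types are involved here.
object BitCountSketch {
  def bitCount32(x: Int): Int  = Integer.bitCount(x)
  def bitCount64(x: Long): Int = java.lang.Long.bitCount(x)

  def main(args: Array[String]): Unit = {
    assert(bitCount32(0xdeadbeef) == 24) // 24 set bits in 0xdeadbeef
    assert(bitCount32(0) == 0)
    assert(bitCount64(-1L) == 64)        // all 64 bits set in two's-complement -1
  }
}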
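// --- Supplementary sketch (editor's addition, not part of the patch) ---
// Eviction mechanism that LongToRegionValueCache (utils/Cache.scala hunk above) relies on:
// a java.util.LinkedHashMap constructed with accessOrder = true keeps entries in
// least-recently-used order, and overriding removeEldestEntry is the hook for
// capacity-based eviction, where the real cache releases the evicted RegionMemory.
// This stand-in uses a String payload and Option instead of the patch's -1L address
// sentinel, so it runs without Hail's Region machinery; all names here are illustrative.
import java.util
import java.util.Map.Entry

class LruCacheSketch(capacity: Int) {
  private[this] val m = new util.LinkedHashMap[Long, String](capacity, 0.75f, true) {
    override def removeEldestEntry(eldest: Entry[Long, String]): Boolean = {
      val evict = size() > capacity
      if (evict) release(eldest.getValue) // analogous to rm.release() in the patch
      evict
    }
  }

  private def release(v: String): Unit = println(s"released: $v")

  def put(key: Long, value: String): Unit = m.put(key, value)

  def get(key: Long): Option[String] = Option(m.get(key))
}

object LruCacheSketchDemo {
  def main(args: Array[String]): Unit = {
    val c = new LruCacheSketch(2)
    c.put(1L, "a"); c.put(2L, "b")
    c.get(1L)      // touch key 1, so key 2 becomes least recently used
    c.put(3L, "c") // exceeds capacity: evicts key 2 and prints "released: b"
    assert(c.get(2L).isEmpty && c.get(1L).contains("a"))
  }
}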
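// --- Supplementary sketch (editor's addition, not part of the patch) ---
// Why the case reordering in is/hail/services/package.scala can change behaviour:
// Scala match cases are tried top to bottom, and in several common HTTP client
// libraries (Apache HttpClient, the Google HTTP client) HttpResponseException is a
// subclass of IOException. If that is the type in scope there (an assumption -- the
// import is not visible in this hunk), the old ordering let the broad
// SSLException/StorageException/IOException case capture the exception before the
// "Invalid grant: account not found" check could run. The classes below are
// illustrative stand-ins, not Hail or HTTP-client types.
object CaseOrderSketch {
  class BroadError extends Exception     // stand-in for IOException
  class SpecificError extends BroadError // stand-in for an HttpResponseException that extends IOException

  def classifyOld(e: Exception): String = e match {
    case _: BroadError    => "generic IO retry path"      // shadows the case below
    case _: SpecificError => "checked for 'Invalid grant'" // unreachable (compiler warns)
    case _                => "not retried"
  }

  def classifyNew(e: Exception): String = e match {
    case _: SpecificError => "checked for 'Invalid grant'"
    case _: BroadError    => "generic IO retry path"
    case _                => "not retried"
  }

  def main(args: Array[String]): Unit = {
    assert(classifyOld(new SpecificError) == "generic IO retry path")
    assert(classifyNew(new SpecificError) == "checked for 'Invalid grant'")
  }
}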